From b5605dfecc1f277e03165b374bd6ce81638ccb36 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 May 2016 17:37:01 +0100 Subject: Refactor SyncHandler --- synapse/handlers/sync.py | 984 +++++++++++++++++++++++------------------------ 1 file changed, 484 insertions(+), 500 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9ebfccc8bf..80eccf19ae 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -194,157 +194,7 @@ class SyncHandler(object): Returns: A Deferred SyncResult. """ - if since_token is None or full_state: - return self.full_state_sync(sync_config, since_token) - else: - return self.incremental_sync_with_gap(sync_config, since_token) - - @defer.inlineCallbacks - def full_state_sync(self, sync_config, timeline_since_token): - """Get a sync for a client which is starting without any state. - - If a 'message_since_token' is given, only timeline events which have - happened since that token will be returned. - - Returns: - A Deferred SyncResult. - """ - now_token = yield self.event_sources.get_current_token() - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token - ) - - presence_stream = self.event_sources.sources["presence"] - # TODO (mjark): This looks wrong, shouldn't we be getting the presence - # UP to the present rather than after the present? - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user=sync_config.user, - pagination_config=pagination_config.get_source_config("presence"), - key=None - ) - - membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN - ) - - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=sync_config.user.to_string(), - membership_list=membership_list - ) - - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) - ) - - account_data['m.push_rules'] = yield self.push_rules_for_user( - sync_config.user - ) - - tags_by_room = yield self.store.get_tags_for_user( - sync_config.user.to_string() - ) - - ignored_users = account_data.get( - "m.ignored_user_list", {} - ).get("ignored_users", {}).keys() - - joined = [] - invited = [] - archived = [] - - user_id = sync_config.user.to_string() - - @defer.inlineCallbacks - def _generate_room_entry(event): - if event.membership == Membership.JOIN: - room_result = yield self.full_state_sync_for_joined_room( - room_id=event.room_id, - sync_config=sync_config, - now_token=now_token, - timeline_since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - joined.append(room_result) - elif event.membership == Membership.INVITE: - if event.sender in ignored_users: - return - invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) - elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if user_id == event.sender: - return - - leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) - ) - room_result = yield self.full_state_sync_for_archived_room( - sync_config=sync_config, - room_id=event.room_id, - leave_event_id=event.event_id, - leave_token=leave_token, - timeline_since_token=timeline_since_token, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - archived.append(room_result) - - yield concurrently_execute(_generate_room_entry, room_list, 10) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - - @defer.inlineCallbacks - def full_state_sync_for_joined_room(self, room_id, sync_config, - now_token, timeline_since_token, - ephemeral_by_room, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred JoinedSyncResult. - """ - - batch = yield self.load_filtered_recents( - room_id, sync_config, now_token, since_token=timeline_since_token - ) - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id, sync_config, - now_token=now_token, - since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=True, - ) - - defer.returnValue(room_sync) + return self.generate_sync_result(sync_config, since_token, full_state) @defer.inlineCallbacks def push_rules_for_user(self, user): @@ -365,24 +215,6 @@ class SyncHandler(object): return account_data_events - def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): - account_data_events = [] - tags = tags_by_room.get(room_id) - if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): """Get the ephemeral events for each room the user is in @@ -445,237 +277,6 @@ class SyncHandler(object): defer.returnValue((now_token, ephemeral_by_room)) - def full_state_sync_for_archived_room(self, room_id, sync_config, - leave_event_id, leave_token, - timeline_since_token, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred ArchivedSyncResult. - """ - - return self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room, - account_data_by_room, full_state=True, leave_token=leave_token, - ) - - @defer.inlineCallbacks - def incremental_sync_with_gap(self, sync_config, since_token): - """ Get the incremental delta needed to bring the client up to - date with the server. - Returns: - A Deferred SyncResult. - """ - now_token = yield self.event_sources.get_current_token() - - rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) - room_ids = [room.room_id for room in rooms] - - presence_source = self.event_sources.sources["presence"] - presence, presence_key = yield presence_source.get_new_events( - user=sync_config.user, - from_key=since_token.presence_key, - limit=sync_config.filter_collection.presence_limit(), - room_ids=room_ids, - is_guest=sync_config.is_guest, - ) - now_token = now_token.copy_and_replace("presence_key", presence_key) - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token, since_token - ) - - app_service = yield self.store.get_app_service_by_user_id( - sync_config.user.to_string() - ) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) - else: - rooms = yield self.store.get_rooms_for_user( - sync_config.user.to_string() - ) - joined_room_ids = set(r.room_id for r in rooms) - - user_id = sync_config.user.to_string() - - timeline_limit = sync_config.filter_collection.timeline_limit() - - tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, - ) - - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, - ) - ) - - push_rules_changed = yield self.store.have_push_rules_changed_for_user( - user_id, int(since_token.push_rules_key) - ) - - if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( - sync_config.user - ) - - ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id, - ) - - if ignored_account_data: - ignored_users = ignored_account_data.get("ignored_users", {}).keys() - else: - ignored_users = frozenset() - - # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key - ) - - mem_change_events_by_room_id = {} - for event in rooms_changed: - mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - - newly_joined_rooms = [] - archived = [] - invited = [] - for room_id, events in mem_change_events_by_room_id.items(): - non_joins = [e for e in events if e.membership != Membership.JOIN] - has_join = len(non_joins) != len(events) - - # We want to figure out if we joined the room at some point since - # the last sync (even if we have since left). This is to make sure - # we do send down the room, and with full state, where necessary - if room_id in joined_room_ids or has_join: - old_state = yield self.get_state_at(room_id, since_token) - old_mem_ev = old_state.get((EventTypes.Member, user_id), None) - if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: - newly_joined_rooms.append(room_id) - - if room_id in joined_room_ids: - continue - - if not non_joins: - continue - - # Only bother if we're still currently invited - should_invite = non_joins[-1].membership == Membership.INVITE - if should_invite: - if event.sender not in ignored_users: - room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) - if room_sync: - invited.append(room_sync) - - # Always include leave/ban events. Just take the last one. - # TODO: How do we handle ban -> leave in same batch? - leave_events = [ - e for e in non_joins - if e.membership in (Membership.LEAVE, Membership.BAN) - ] - - if leave_events: - leave_event = leave_events[-1] - room_sync = yield self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event.event_id, since_token, - tags_by_room, account_data_by_room, - full_state=room_id in newly_joined_rooms - ) - if room_sync: - archived.append(room_sync) - - # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( - room_ids=joined_room_ids, - from_key=since_token.room_key, - to_key=now_token.room_key, - limit=timeline_limit + 1, - ) - - joined = [] - # We loop through all room ids, even if there are no new events, in case - # there are non room events taht we need to notify about. - for room_id in joined_room_ids: - room_entry = room_to_events.get(room_id, None) - - if room_entry: - events, start_key = room_entry - - prev_batch_token = now_token.copy_and_replace("room_key", start_key) - - newly_joined_room = room_id in newly_joined_rooms - full_state = newly_joined_room - - batch = yield self.load_filtered_recents( - room_id, sync_config, prev_batch_token, - since_token=since_token, - recents=events, - newly_joined_room=newly_joined_room, - ) - else: - batch = TimelineBatch( - events=[], - prev_batch=since_token, - limited=False, - ) - full_state = False - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id=room_id, - sync_config=sync_config, - since_token=since_token, - now_token=now_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=full_state, - ) - if room_sync: - joined.append(room_sync) - - # For each newly joined room, we want to send down presence of - # existing users. - presence_handler = self.presence_handler - extra_presence_users = set() - for room_id in newly_joined_rooms: - users = yield self.store.get_users_in_room(event.room_id) - extra_presence_users.update(users) - - # For each new member, send down presence. - for joined_sync in joined: - it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) - for event in it: - if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - extra_presence_users.add(event.state_key) - - states = yield presence_handler.get_states( - [u for u in extra_presence_users if u != user_id], - as_event=True, - ) - presence.extend(states) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - @defer.inlineCallbacks def load_filtered_recents(self, room_id, sync_config, now_token, since_token=None, recents=None, newly_joined_room=False): @@ -696,6 +297,10 @@ class SyncHandler(object): else: limited = False + if since_token: + if not now_token.is_after(since_token): + limited = False + if recents is not None: recents = sync_config.filter_collection.filter_room_timeline(recents) recents = yield filter_events_for_client( @@ -749,106 +354,9 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def incremental_sync_with_gap_for_room(self, room_id, sync_config, - since_token, now_token, - ephemeral_by_room, tags_by_room, - account_data_by_room, - batch, full_state=False): - state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - ephemeral = sync_config.filter_collection.filter_room_ephemeral( - ephemeral_by_room.get(room_id, []) - ) - - unread_notifications = {} - room_sync = JoinedSyncResult( - room_id=room_id, - timeline=batch, - state=state, - ephemeral=ephemeral, - account_data=account_data, - unread_notifications=unread_notifications, - ) - - if room_sync: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) - - if notifs is not None: - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks - def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id, - since_token, tags_by_room, - account_data_by_room, full_state, - leave_token=None): - """ Get the incremental delta needed to bring the client up to date for - the archived room. - Returns: - A Deferred ArchivedSyncResult - """ - - if not leave_token: - stream_token = yield self.store.get_stream_token_for_event( - leave_event_id - ) - - leave_token = since_token.copy_and_replace("room_key", stream_token) - - if since_token and since_token.is_after(leave_token): - defer.returnValue(None) - - batch = yield self.load_filtered_recents( - room_id, sync_config, leave_token, since_token, - ) - - logger.debug("Recents %r", batch) - - state_events_delta = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, leave_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - room_sync = ArchivedSyncResult( - room_id=room_id, - timeline=batch, - state=state_events_delta, - account_data=account_data, - ) - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks - def get_state_after_event(self, event): - """ - Get the room state after the given event + def get_state_after_event(self, event): + """ + Get the room state after the given event Args: event(synapse.events.EventBase): event of interest @@ -1010,6 +518,457 @@ class SyncHandler(object): # count is whatever it was last time. defer.returnValue(None) + @defer.inlineCallbacks + def generate_sync_result(self, sync_config, since_token=None, full_state=False): + now_token = yield self.event_sources.get_current_token() + + sync_result_builer = SyncResultBuilder( + sync_config, full_state, + since_token=since_token, + now_token=now_token, + ) + + account_data_by_room = yield self.generate_sync_entry_for_account_data( + sync_result_builer + ) + + newly_joined_rooms, newly_joined_users = yield self.generate_sync_entry_for_rooms( + sync_result_builer, account_data_by_room + ) + + yield self.generate_sync_entry_for_presence( + sync_result_builer, newly_joined_rooms, newly_joined_users + ) + + defer.returnValue(SyncResult( + presence=sync_result_builer.presence, + account_data=sync_result_builer.account_data, + joined=sync_result_builer.joined, + invited=sync_result_builer.invited, + archived=sync_result_builer.archived, + next_batch=sync_result_builer.now_token, + )) + + @defer.inlineCallbacks + def generate_sync_entry_for_account_data(self, sync_result_builer): + sync_config = sync_result_builer.sync_config + user_id = sync_result_builer.sync_config.user.to_string() + since_token = sync_result_builer.since_token + + if since_token and not sync_result_builer.full_state: + account_data, account_data_by_room = ( + yield self.store.get_updated_account_data_for_user( + user_id, + since_token.account_data_key, + ) + ) + + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + else: + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user( + sync_config.user.to_string() + ) + ) + + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + + account_data_for_user = sync_config.filter_collection.filter_account_data( + self.account_data_for_user(account_data) + ) + + sync_result_builer.account_data = account_data_for_user + + defer.returnValue(account_data_by_room) + + @defer.inlineCallbacks + def generate_sync_entry_for_presence(self, sync_result_builer, newly_joined_rooms, + newly_joined_users): + now_token = sync_result_builer.now_token + sync_config = sync_result_builer.sync_config + user = sync_result_builer.sync_config.user + + presence_source = self.event_sources.sources["presence"] + + since_token = sync_result_builer.since_token + if since_token and not sync_result_builer.full_state: + presence_key = since_token.presence_key + else: + presence_key = None + + presence, presence_key = yield presence_source.get_new_events( + user=user, + from_key=presence_key, + is_guest=sync_config.is_guest, + ) + sync_result_builer.now_token = now_token.copy_and_replace( + "presence_key", presence_key + ) + + extra_users_ids = set(newly_joined_users) + for room_id in newly_joined_rooms: + users = yield self.store.get_users_in_room(room_id) + extra_users_ids.update(users) + extra_users_ids.discard(user.to_string()) + + states = yield self.presence_handler.get_states( + extra_users_ids, + as_event=True, + ) + presence.extend(states) + + presence = sync_config.filter_collection.filter_presence( + presence + ) + + sync_result_builer.presence = presence + + @defer.inlineCallbacks + def generate_sync_entry_for_rooms(self, sync_result_builer, account_data_by_room): + user_id = sync_result_builer.sync_config.user.to_string() + + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_result_builer.sync_config, sync_result_builer.now_token + ) + sync_result_builer.now_token = now_token + + ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + "m.ignored_user_list", user_id=user_id, + ) + + if ignored_account_data: + ignored_users = ignored_account_data.get("ignored_users", {}).keys() + else: + ignored_users = frozenset() + + if sync_result_builer.since_token: + res = yield self._get_rooms_changed(sync_result_builer, ignored_users) + joined, invited, archived, newly_joined_rooms = res + + tags_by_room = yield self.store.get_updated_tags( + user_id, + sync_result_builer.since_token.account_data_key, + ) + else: + res = yield self._get_all_rooms(sync_result_builer, ignored_users) + joined, invited, archived, newly_joined_rooms = res + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + for room_entry in joined: + yield self._generate_room_entry( + "joined", + sync_result_builer, + ignored_users, + room_entry, + ephemeral=ephemeral_by_room.get(room_entry.room_id, []), + tags=tags_by_room.get(room_entry.room_id), + account_data=account_data_by_room.get(room_entry.room_id, {}), + always_include=sync_result_builer.full_state, + ) + for room_entry in archived: + yield self._generate_room_entry( + "archived", + sync_result_builer, + ignored_users, + room_entry, + ephemeral=ephemeral_by_room.get(room_entry.room_id, []), + tags=tags_by_room.get(room_entry.room_id), + account_data=account_data_by_room.get(room_entry.room_id, {}), + always_include=sync_result_builer.full_state, + ) + + sync_result_builer.invited.extend(invited) + + # Now we want to get any newly joined users + newly_joined_users = set() + for joined_sync in sync_result_builer.joined: + it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + newly_joined_users.add(event.state_key) + + defer.returnValue((newly_joined_rooms, newly_joined_users)) + + @defer.inlineCallbacks + def _get_rooms_changed(self, sync_result_builer, ignored_users): + user_id = sync_result_builer.sync_config.user.to_string() + since_token = sync_result_builer.since_token + now_token = sync_result_builer.now_token + sync_config = sync_result_builer.sync_config + + assert since_token + + app_service = yield self.store.get_app_service_by_user_id(user_id) + if app_service: + rooms = yield self.store.get_app_service_rooms(app_service) + joined_room_ids = set(r.room_id for r in rooms) + else: + rooms = yield self.store.get_rooms_for_user(user_id) + joined_room_ids = set(r.room_id for r in rooms) + + # Get a list of membership change events that have happened. + rooms_changed = yield self.store.get_membership_changes_for_user( + user_id, since_token.room_key, now_token.room_key + ) + + mem_change_events_by_room_id = {} + for event in rooms_changed: + mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) + + newly_joined_rooms = [] + archived = [] + invited = [] + for room_id, events in mem_change_events_by_room_id.items(): + non_joins = [e for e in events if e.membership != Membership.JOIN] + has_join = len(non_joins) != len(events) + + # We want to figure out if we joined the room at some point since + # the last sync (even if we have since left). This is to make sure + # we do send down the room, and with full state, where necessary + if room_id in joined_room_ids or has_join: + old_state = yield self.get_state_at(room_id, since_token) + old_mem_ev = old_state.get((EventTypes.Member, user_id), None) + if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: + newly_joined_rooms.append(room_id) + + if room_id in joined_room_ids: + continue + + if not non_joins: + continue + + # Only bother if we're still currently invited + should_invite = non_joins[-1].membership == Membership.INVITE + if should_invite: + if event.sender not in ignored_users: + room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if room_sync: + invited.append(room_sync) + + # Always include leave/ban events. Just take the last one. + # TODO: How do we handle ban -> leave in same batch? + leave_events = [ + e for e in non_joins + if e.membership in (Membership.LEAVE, Membership.BAN) + ] + + if leave_events: + leave_event = leave_events[-1] + leave_stream_token = yield self.store.get_stream_token_for_event( + leave_event.event_id + ) + leave_token = since_token.copy_and_replace( + "room_key", leave_stream_token + ) + + if since_token and since_token.is_after(leave_token): + continue + + archived.append(RoomSyncResultBuilder( + room_id=room_id, + events=None, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + )) + + timeline_limit = sync_config.filter_collection.timeline_limit() + + # Get all events for rooms we're currently joined to. + room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_ids=joined_room_ids, + from_key=since_token.room_key, + to_key=now_token.room_key, + limit=timeline_limit + 1, + ) + + joined = [] + # We loop through all room ids, even if there are no new events, in case + # there are non room events taht we need to notify about. + for room_id in joined_room_ids: + room_entry = room_to_events.get(room_id, None) + + if room_entry: + events, start_key = room_entry + + prev_batch_token = now_token.copy_and_replace("room_key", start_key) + + joined.append(RoomSyncResultBuilder( + room_id=room_id, + events=events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=None if room_id in newly_joined_rooms else since_token, + upto_token=prev_batch_token, + )) + else: + joined.append(RoomSyncResultBuilder( + room_id=room_id, + events=[], + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=since_token, + )) + + defer.returnValue((joined, invited, archived, newly_joined_rooms)) + + @defer.inlineCallbacks + def _get_all_rooms(self, sync_result_builer, ignored_users): + user_id = sync_result_builer.sync_config.user.to_string() + since_token = sync_result_builer.since_token + now_token = sync_result_builer.now_token + sync_config = sync_result_builer.sync_config + + membership_list = ( + Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + ) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, + membership_list=membership_list + ) + + joined = [] + invited = [] + archived = [] + + for event in room_list: + if event.membership == Membership.JOIN: + joined.append(RoomSyncResultBuilder( + room_id=event.room_id, + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + )) + elif event.membership == Membership.INVITE: + if event.sender in ignored_users: + continue + invite = yield self.store.get_event(event.event_id) + invited.append(InvitedSyncResult( + room_id=event.room_id, + invite=invite, + )) + elif event.membership in (Membership.LEAVE, Membership.BAN): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + continue + + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + archived.append(RoomSyncResultBuilder( + room_id=event.room_id, + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + )) + + defer.returnValue((joined, invited, archived, [])) + + @defer.inlineCallbacks + def _generate_room_entry(self, room_type, sync_result_builer, ignored_users, + room_builder, ephemeral, tags, account_data, + always_include=False): + since_token = sync_result_builer.since_token + now_token = sync_result_builer.now_token + sync_config = sync_result_builer.sync_config + + room_id = room_builder.room_id + events = room_builder.events + newly_joined = room_builder.newly_joined + full_state = ( + room_builder.full_state + or newly_joined + or sync_result_builer.full_state + ) + since_token = room_builder.since_token + upto_token = room_builder.upto_token + + batch = yield self.load_filtered_recents( + room_id, sync_config, + now_token=upto_token, + since_token=since_token, + recents=events, + newly_joined_room=newly_joined, # FIXME + ) + + account_data_events = [] + if tags is not None: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data_events + ) + + ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) + + if not (always_include or batch or account_data or ephemeral or full_state): + return + + state = yield self.compute_state_delta( + room_id, batch, sync_config, since_token, now_token, + full_state=full_state + ) + + if room_type == "joined": + unread_notifications = {} + room_sync = JoinedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + ephemeral=ephemeral, + account_data=account_data_events, + unread_notifications=unread_notifications, + ) + + if room_sync or always_include: + notifs = yield self.unread_notifs_for_room_id( + room_id, sync_config + ) + + if notifs is not None: + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = notifs["highlight_count"] + + sync_result_builer.joined.append(room_sync) + elif room_type == "archived": + room_sync = ArchivedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + account_data=account_data, + ) + if room_sync or always_include: + sync_result_builer.archived.append(room_sync) + def _action_has_highlight(actions): for action in actions: @@ -1057,3 +1016,28 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): (e.type, e.state_key): e for e in evs } + + +class SyncResultBuilder(object): + def __init__(self, sync_config, full_state, since_token, now_token): + self.sync_config = sync_config + self.full_state = full_state + self.since_token = since_token + self.now_token = now_token + + self.presence = [] + self.account_data = [] + self.joined = [] + self.invited = [] + self.archived = [] + + +class RoomSyncResultBuilder(object): + def __init__(self, room_id, events, newly_joined, full_state, since_token, + upto_token): + self.room_id = room_id + self.events = events + self.newly_joined = newly_joined + self.full_state = full_state + self.since_token = since_token + self.upto_token = upto_token -- cgit 1.5.1 From c0c79ef444ca0f21e8324abc8a813026aaf6cf17 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 May 2016 18:21:27 +0100 Subject: Add back concurrently_execute --- synapse/handlers/sync.py | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 80eccf19ae..bc6d6af133 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute from synapse.util.logcontext import LoggingContext @@ -478,26 +477,6 @@ class SyncHandler(object): for e in sync_config.filter_collection.filter_room_state(state.values()) }) - def check_joined_room(self, sync_config, state_delta): - """ - Check if the user has just joined the given room (so should - be given the full state) - - Args: - sync_config(synapse.handlers.sync.SyncConfig): - state_delta(dict[(str,str), synapse.events.FrozenEvent]): the - difference in state since the last sync - - Returns: - A deferred Tuple (state_delta, limited) - """ - join_event = state_delta.get(( - EventTypes.Member, sync_config.user.to_string()), None) - if join_event is not None: - if join_event.content["membership"] == Membership.JOIN: - return True - return False - @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): with Measure(self.clock, "unread_notifs_for_room_id"): @@ -664,8 +643,8 @@ class SyncHandler(object): tags_by_room = yield self.store.get_tags_for_user(user_id) - for room_entry in joined: - yield self._generate_room_entry( + def handle_joined(room_entry): + return self._generate_room_entry( "joined", sync_result_builer, ignored_users, @@ -675,8 +654,11 @@ class SyncHandler(object): account_data=account_data_by_room.get(room_entry.room_id, {}), always_include=sync_result_builer.full_state, ) - for room_entry in archived: - yield self._generate_room_entry( + + yield concurrently_execute(handle_joined, joined, 10) + + def handle_archived(room_entry): + return self._generate_room_entry( "archived", sync_result_builer, ignored_users, @@ -687,6 +669,8 @@ class SyncHandler(object): always_include=sync_result_builer.full_state, ) + yield concurrently_execute(handle_archived, archived, 10) + sync_result_builer.invited.extend(invited) # Now we want to get any newly joined users -- cgit 1.5.1 From 6fe04ffef2fd44a3af61fe262266b25158c46db7 Mon Sep 17 00:00:00 2001 From: Negi Fazeli Date: Mon, 23 May 2016 11:14:11 +0200 Subject: Fix set profile error with Requester. Replace flush_user with delete access token due to function removal Add a new test case for if the user is already registered --- synapse/handlers/register.py | 9 +++++---- tests/handlers/test_register.py | 34 +++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 13 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 5883b9111e..16f33f8371 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -16,7 +16,7 @@ """Contains functions for registering clients.""" from twisted.internet import defer -from synapse.types import UserID +from synapse.types import UserID, Requester from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) @@ -360,7 +360,8 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def get_or_create_user(self, localpart, displayname, duration_seconds): - """Creates a new user or returns an access token for an existing one + """Creates a new user if the user does not exist, + else revokes all previous access tokens and generates a new one. Args: localpart : The local part of the user ID to register. If None, @@ -399,14 +400,14 @@ class RegistrationHandler(BaseHandler): yield registered_user(self.distributor, user) else: - yield self.store.flush_user(user_id=user_id) + yield self.store.user_delete_access_tokens(user_id=user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler yield profile_handler.set_displayname( - user, user, displayname + user, Requester(user, token, False), displayname ) defer.returnValue((user_id, token)) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 8b7be96bd9..9d5c653b45 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,6 +17,7 @@ from twisted.internet import defer from .. import unittest from synapse.handlers.register import RegistrationHandler +from synapse.types import UserID from tests.utils import setup_test_homeserver @@ -36,25 +37,21 @@ class RegistrationTestCase(unittest.TestCase): self.mock_distributor = Mock() self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() - hs = yield setup_test_homeserver( + self.hs = yield setup_test_homeserver( handlers=None, http_client=None, expire_access_token=True) - hs.handlers = RegistrationHandlers(hs) - self.handler = hs.get_handlers().registration_handler - hs.get_handlers().profile_handler = Mock() + self.hs.handlers = RegistrationHandlers(self.hs) + self.handler = self.hs.get_handlers().registration_handler + self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ "generate_short_term_login_token", ]) - hs.get_handlers().auth_handler = self.mock_handler + self.hs.get_handlers().auth_handler = self.mock_handler @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - """ - Returns: - The user doess not exist in this case so it will register and log it in - """ duration_ms = 200 local_part = "someone" display_name = "someone" @@ -65,3 +62,22 @@ class RegistrationTestCase(unittest.TestCase): local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') + + @defer.inlineCallbacks + def test_if_user_exists(self): + store = self.hs.get_datastore() + frank = UserID.from_string("@frank:test") + yield store.register( + user_id=frank.to_string(), + token="jkv;g498752-43gj['eamb!-5", + password_hash=None) + duration_ms = 200 + local_part = "frank" + display_name = "Frank" + user_id = "@frank:test" + mock_token = self.mock_handler.generate_short_term_login_token + mock_token.return_value = 'secret' + result_user_id, result_token = yield self.handler.get_or_create_user( + local_part, display_name, duration_ms) + self.assertEquals(result_user_id, user_id) + self.assertEquals(result_token, 'secret') -- cgit 1.5.1 From 137e6a45577d5850ef6936670791af12c2fe74d9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 09:43:35 +0100 Subject: Shuffle things room --- synapse/handlers/sync.py | 70 +++++++++++++++++++++++------------------------- 1 file changed, 33 insertions(+), 37 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bc6d6af133..6b7c6a436e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -499,6 +499,10 @@ class SyncHandler(object): @defer.inlineCallbacks def generate_sync_result(self, sync_config, since_token=None, full_state=False): + # NB: The now_token gets changed by some of the generate_sync_* methods, + # this is due to some of the underlying streams not supporting the ability + # to query up to a given point. + # Always use the `now_token` in `SyncResultBuilder` now_token = yield self.event_sources.get_current_token() sync_result_builer = SyncResultBuilder( @@ -511,9 +515,10 @@ class SyncHandler(object): sync_result_builer ) - newly_joined_rooms, newly_joined_users = yield self.generate_sync_entry_for_rooms( + res = yield self.generate_sync_entry_for_rooms( sync_result_builer, account_data_by_room ) + newly_joined_rooms, newly_joined_users = res yield self.generate_sync_entry_for_presence( sync_result_builer, newly_joined_rooms, newly_joined_users @@ -631,7 +636,7 @@ class SyncHandler(object): if sync_result_builer.since_token: res = yield self._get_rooms_changed(sync_result_builer, ignored_users) - joined, invited, archived, newly_joined_rooms = res + room_entries, invited, newly_joined_rooms = res tags_by_room = yield self.store.get_updated_tags( user_id, @@ -639,13 +644,12 @@ class SyncHandler(object): ) else: res = yield self._get_all_rooms(sync_result_builer, ignored_users) - joined, invited, archived, newly_joined_rooms = res + room_entries, invited, newly_joined_rooms = res tags_by_room = yield self.store.get_tags_for_user(user_id) - def handle_joined(room_entry): + def handle_room_entries(room_entry): return self._generate_room_entry( - "joined", sync_result_builer, ignored_users, room_entry, @@ -655,21 +659,7 @@ class SyncHandler(object): always_include=sync_result_builer.full_state, ) - yield concurrently_execute(handle_joined, joined, 10) - - def handle_archived(room_entry): - return self._generate_room_entry( - "archived", - sync_result_builer, - ignored_users, - room_entry, - ephemeral=ephemeral_by_room.get(room_entry.room_id, []), - tags=tags_by_room.get(room_entry.room_id), - account_data=account_data_by_room.get(room_entry.room_id, {}), - always_include=sync_result_builer.full_state, - ) - - yield concurrently_execute(handle_archived, archived, 10) + yield concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builer.invited.extend(invited) @@ -711,7 +701,7 @@ class SyncHandler(object): mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) newly_joined_rooms = [] - archived = [] + room_entries = [] invited = [] for room_id, events in mem_change_events_by_room_id.items(): non_joins = [e for e in events if e.membership != Membership.JOIN] @@ -759,8 +749,9 @@ class SyncHandler(object): if since_token and since_token.is_after(leave_token): continue - archived.append(RoomSyncResultBuilder( + room_entries.append(RoomSyncResultBuilder( room_id=room_id, + rtype="archived", events=None, newly_joined=room_id in newly_joined_rooms, full_state=False, @@ -778,7 +769,6 @@ class SyncHandler(object): limit=timeline_limit + 1, ) - joined = [] # We loop through all room ids, even if there are no new events, in case # there are non room events taht we need to notify about. for room_id in joined_room_ids: @@ -789,8 +779,9 @@ class SyncHandler(object): prev_batch_token = now_token.copy_and_replace("room_key", start_key) - joined.append(RoomSyncResultBuilder( + room_entries.append(RoomSyncResultBuilder( room_id=room_id, + rtype="joined", events=events, newly_joined=room_id in newly_joined_rooms, full_state=False, @@ -798,8 +789,9 @@ class SyncHandler(object): upto_token=prev_batch_token, )) else: - joined.append(RoomSyncResultBuilder( + room_entries.append(RoomSyncResultBuilder( room_id=room_id, + rtype="joined", events=[], newly_joined=room_id in newly_joined_rooms, full_state=False, @@ -807,7 +799,7 @@ class SyncHandler(object): upto_token=since_token, )) - defer.returnValue((joined, invited, archived, newly_joined_rooms)) + defer.returnValue((room_entries, invited, newly_joined_rooms)) @defer.inlineCallbacks def _get_all_rooms(self, sync_result_builer, ignored_users): @@ -825,14 +817,14 @@ class SyncHandler(object): membership_list=membership_list ) - joined = [] + room_entries = [] invited = [] - archived = [] for event in room_list: if event.membership == Membership.JOIN: - joined.append(RoomSyncResultBuilder( + room_entries.append(RoomSyncResultBuilder( room_id=event.room_id, + rtype="joined", events=None, newly_joined=False, full_state=True, @@ -857,8 +849,9 @@ class SyncHandler(object): leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) - archived.append(RoomSyncResultBuilder( + room_entries.append(RoomSyncResultBuilder( room_id=event.room_id, + rtype="archived", events=None, newly_joined=False, full_state=True, @@ -866,10 +859,10 @@ class SyncHandler(object): upto_token=leave_token, )) - defer.returnValue((joined, invited, archived, [])) + defer.returnValue((room_entries, invited, [])) @defer.inlineCallbacks - def _generate_room_entry(self, room_type, sync_result_builer, ignored_users, + def _generate_room_entry(self, sync_result_builer, ignored_users, room_builder, ephemeral, tags, account_data, always_include=False): since_token = sync_result_builer.since_token @@ -892,7 +885,7 @@ class SyncHandler(object): now_token=upto_token, since_token=since_token, recents=events, - newly_joined_room=newly_joined, # FIXME + newly_joined_room=newly_joined, ) account_data_events = [] @@ -922,7 +915,7 @@ class SyncHandler(object): full_state=full_state ) - if room_type == "joined": + if room_builder.rtype == "joined": unread_notifications = {} room_sync = JoinedSyncResult( room_id=room_id, @@ -943,7 +936,7 @@ class SyncHandler(object): unread_notifications["highlight_count"] = notifs["highlight_count"] sync_result_builer.joined.append(room_sync) - elif room_type == "archived": + elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( room_id=room_id, timeline=batch, @@ -952,6 +945,8 @@ class SyncHandler(object): ) if room_sync or always_include: sync_result_builer.archived.append(room_sync) + else: + raise Exception("Unrecognized rtype: %r", room_builder.rtype) def _action_has_highlight(actions): @@ -1017,9 +1012,10 @@ class SyncResultBuilder(object): class RoomSyncResultBuilder(object): - def __init__(self, room_id, events, newly_joined, full_state, since_token, - upto_token): + def __init__(self, room_id, rtype, events, newly_joined, full_state, + since_token, upto_token): self.room_id = room_id + self.rtype = rtype self.events = events self.newly_joined = newly_joined self.full_state = full_state -- cgit 1.5.1 From 84f94e4cbbd8c79d867503159b564aaf233e46d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 10:14:53 +0100 Subject: Add comments --- synapse/handlers/sync.py | 112 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 105 insertions(+), 7 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6b7c6a436e..271dd4c147 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -499,6 +499,17 @@ class SyncHandler(object): @defer.inlineCallbacks def generate_sync_result(self, sync_config, since_token=None, full_state=False): + """Generates a sync result. + + Args: + sync_config (SyncConfig) + since_token (StreamToken) + full_state (bool) + + Returns: + Deferred(SyncResult) + """ + # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -511,16 +522,16 @@ class SyncHandler(object): now_token=now_token, ) - account_data_by_room = yield self.generate_sync_entry_for_account_data( + account_data_by_room = yield self._generate_sync_entry_for_account_data( sync_result_builer ) - res = yield self.generate_sync_entry_for_rooms( + res = yield self._generate_sync_entry_for_rooms( sync_result_builer, account_data_by_room ) newly_joined_rooms, newly_joined_users = res - yield self.generate_sync_entry_for_presence( + yield self._generate_sync_entry_for_presence( sync_result_builer, newly_joined_rooms, newly_joined_users ) @@ -534,7 +545,16 @@ class SyncHandler(object): )) @defer.inlineCallbacks - def generate_sync_entry_for_account_data(self, sync_result_builer): + def _generate_sync_entry_for_account_data(self, sync_result_builer): + """Generates the account data portion of the sync response. Populates + `sync_result_builer` with the result. + + Args: + sync_result_builer(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ sync_config = sync_result_builer.sync_config user_id = sync_result_builer.sync_config.user.to_string() since_token = sync_result_builer.since_token @@ -575,8 +595,18 @@ class SyncHandler(object): defer.returnValue(account_data_by_room) @defer.inlineCallbacks - def generate_sync_entry_for_presence(self, sync_result_builer, newly_joined_rooms, - newly_joined_users): + def _generate_sync_entry_for_presence(self, sync_result_builer, newly_joined_rooms, + newly_joined_users): + """Generates the presence portion of the sync response. Populates the + `sync_result_builer` with the result. + + Args: + sync_result_builer(SyncResultBuilder) + newly_joined_rooms(list): List of rooms that the user has joined + since the last sync (or empty if an initial sync) + newly_joined_users(list): List of users that have joined rooms + since the last sync (or empty if an initial sync) + """ now_token = sync_result_builer.now_token sync_config = sync_result_builer.sync_config user = sync_result_builer.sync_config.user @@ -617,7 +647,18 @@ class SyncHandler(object): sync_result_builer.presence = presence @defer.inlineCallbacks - def generate_sync_entry_for_rooms(self, sync_result_builer, account_data_by_room): + def _generate_sync_entry_for_rooms(self, sync_result_builer, account_data_by_room): + """Generates the rooms portion of the sync response. Populates the + `sync_result_builer` with the result. + + Args: + sync_result_builer(SyncResultBuilder) + account_data_by_room(dict): Dictionary of per room account data + + Returns: + Deferred(tuple): Returns a 2-tuple of + `(newly_joined_rooms, newly_joined_users)` + """ user_id = sync_result_builer.sync_config.user.to_string() now_token, ephemeral_by_room = yield self.ephemeral_by_room( @@ -676,6 +717,16 @@ class SyncHandler(object): @defer.inlineCallbacks def _get_rooms_changed(self, sync_result_builer, ignored_users): + """Gets the the changes that have happened since the last sync. + + Args: + sync_result_builer(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)` + """ user_id = sync_result_builer.sync_config.user.to_string() since_token = sync_result_builer.since_token now_token = sync_result_builer.now_token @@ -803,6 +854,17 @@ class SyncHandler(object): @defer.inlineCallbacks def _get_all_rooms(self, sync_result_builer, ignored_users): + """Returns entries for all rooms for the user. + + Args: + sync_result_builer(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], [])` + """ + user_id = sync_result_builer.sync_config.user.to_string() since_token = sync_result_builer.since_token now_token = sync_result_builer.now_token @@ -865,6 +927,20 @@ class SyncHandler(object): def _generate_room_entry(self, sync_result_builer, ignored_users, room_builder, ephemeral, tags, account_data, always_include=False): + """Populates the `joined` and `archived` section of `sync_result_builer` + based on the `room_builder`. + + Args: + sync_result_builer(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + room_builder(RoomSyncResultBuilder) + ephemeral(list): List of new ephemeral events for room + tags(list): List of *all* tags for room, or None if there has been + no change. + account_data(list): List of new account data for room + always_include(bool): Always include this room in the sync response, + even if empty. + """ since_token = sync_result_builer.since_token now_token = sync_result_builer.now_token sync_config = sync_result_builer.sync_config @@ -998,7 +1074,15 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): class SyncResultBuilder(object): + "Used to help build up a new SyncResult for a user" def __init__(self, sync_config, full_state, since_token, now_token): + """ + Args: + sync_config(SyncConfig) + full_state(bool): The full_state flag as specified by user + since_token(StreamToken): The token supplied by user, or None. + now_token(StreamToken): The token to sync up to. + """ self.sync_config = sync_config self.full_state = full_state self.since_token = since_token @@ -1012,8 +1096,22 @@ class SyncResultBuilder(object): class RoomSyncResultBuilder(object): + """Stores information needed to create either a `JoinedSyncResult` or + `ArchivedSyncResult`. + """ def __init__(self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token): + """ + Args: + room_id(str) + rtype(str): One of `"joined"` or `"archived"` + events(list): List of events to include in the room, (more events + may be added when generating result). + newly_joined(bool): If the user has newly joined the room + full_state(bool): Whether the full state should be sent in result + since_token(StreamToken): Earliest point to return events from, or None + upto_token(StreamToken): Latest point to return events from. + """ self.room_id = room_id self.rtype = rtype self.events = events -- cgit 1.5.1 From 79bea8ab9a4c4cb4a0907784603698731424c00a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 10:22:24 +0100 Subject: Inline function. Make load_filtered_recents private --- synapse/handlers/sync.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 271dd4c147..3aa4432d68 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -203,17 +203,6 @@ class SyncHandler(object): rules = format_push_rules_for_user(user, rawrules, enabled_map) defer.returnValue(rules) - def account_data_for_user(self, account_data): - account_data_events = [] - - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): """Get the ephemeral events for each room the user is in @@ -277,8 +266,8 @@ class SyncHandler(object): defer.returnValue((now_token, ephemeral_by_room)) @defer.inlineCallbacks - def load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): + def _load_filtered_recents(self, room_id, sync_config, now_token, + since_token=None, recents=None, newly_joined_room=False): """ Returns: a Deferred TimelineBatch @@ -586,9 +575,10 @@ class SyncHandler(object): sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) + account_data_for_user = sync_config.filter_collection.filter_account_data([ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ]) sync_result_builer.account_data = account_data_for_user @@ -956,7 +946,7 @@ class SyncHandler(object): since_token = room_builder.since_token upto_token = room_builder.upto_token - batch = yield self.load_filtered_recents( + batch = yield self._load_filtered_recents( room_id, sync_config, now_token=upto_token, since_token=since_token, -- cgit 1.5.1 From be2c67738640ff88fcf63701c9e3d82afc385e47 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 10:53:03 +0100 Subject: Spell builder correctly --- synapse/handlers/sync.py | 126 +++++++++++++++++++++++------------------------ 1 file changed, 63 insertions(+), 63 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3aa4432d68..12a4cc8b57 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -505,50 +505,50 @@ class SyncHandler(object): # Always use the `now_token` in `SyncResultBuilder` now_token = yield self.event_sources.get_current_token() - sync_result_builer = SyncResultBuilder( + sync_result_builder = SyncResultBuilder( sync_config, full_state, since_token=since_token, now_token=now_token, ) account_data_by_room = yield self._generate_sync_entry_for_account_data( - sync_result_builer + sync_result_builder ) res = yield self._generate_sync_entry_for_rooms( - sync_result_builer, account_data_by_room + sync_result_builder, account_data_by_room ) newly_joined_rooms, newly_joined_users = res yield self._generate_sync_entry_for_presence( - sync_result_builer, newly_joined_rooms, newly_joined_users + sync_result_builder, newly_joined_rooms, newly_joined_users ) defer.returnValue(SyncResult( - presence=sync_result_builer.presence, - account_data=sync_result_builer.account_data, - joined=sync_result_builer.joined, - invited=sync_result_builer.invited, - archived=sync_result_builer.archived, - next_batch=sync_result_builer.now_token, + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + next_batch=sync_result_builder.now_token, )) @defer.inlineCallbacks - def _generate_sync_entry_for_account_data(self, sync_result_builer): + def _generate_sync_entry_for_account_data(self, sync_result_builder): """Generates the account data portion of the sync response. Populates - `sync_result_builer` with the result. + `sync_result_builder` with the result. Args: - sync_result_builer(SyncResultBuilder) + sync_result_builder(SyncResultBuilder) Returns: Deferred(dict): A dictionary containing the per room account data. """ - sync_config = sync_result_builer.sync_config - user_id = sync_result_builer.sync_config.user.to_string() - since_token = sync_result_builer.since_token + sync_config = sync_result_builder.sync_config + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token - if since_token and not sync_result_builer.full_state: + if since_token and not sync_result_builder.full_state: account_data, account_data_by_room = ( yield self.store.get_updated_account_data_for_user( user_id, @@ -580,31 +580,31 @@ class SyncHandler(object): for account_data_type, content in account_data.items() ]) - sync_result_builer.account_data = account_data_for_user + sync_result_builder.account_data = account_data_for_user defer.returnValue(account_data_by_room) @defer.inlineCallbacks - def _generate_sync_entry_for_presence(self, sync_result_builer, newly_joined_rooms, + def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, newly_joined_users): """Generates the presence portion of the sync response. Populates the - `sync_result_builer` with the result. + `sync_result_builder` with the result. Args: - sync_result_builer(SyncResultBuilder) + sync_result_builder(SyncResultBuilder) newly_joined_rooms(list): List of rooms that the user has joined since the last sync (or empty if an initial sync) newly_joined_users(list): List of users that have joined rooms since the last sync (or empty if an initial sync) """ - now_token = sync_result_builer.now_token - sync_config = sync_result_builer.sync_config - user = sync_result_builer.sync_config.user + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + user = sync_result_builder.sync_config.user presence_source = self.event_sources.sources["presence"] - since_token = sync_result_builer.since_token - if since_token and not sync_result_builer.full_state: + since_token = sync_result_builder.since_token + if since_token and not sync_result_builder.full_state: presence_key = since_token.presence_key else: presence_key = None @@ -614,7 +614,7 @@ class SyncHandler(object): from_key=presence_key, is_guest=sync_config.is_guest, ) - sync_result_builer.now_token = now_token.copy_and_replace( + sync_result_builder.now_token = now_token.copy_and_replace( "presence_key", presence_key ) @@ -634,27 +634,27 @@ class SyncHandler(object): presence ) - sync_result_builer.presence = presence + sync_result_builder.presence = presence @defer.inlineCallbacks - def _generate_sync_entry_for_rooms(self, sync_result_builer, account_data_by_room): + def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): """Generates the rooms portion of the sync response. Populates the - `sync_result_builer` with the result. + `sync_result_builder` with the result. Args: - sync_result_builer(SyncResultBuilder) + sync_result_builder(SyncResultBuilder) account_data_by_room(dict): Dictionary of per room account data Returns: Deferred(tuple): Returns a 2-tuple of `(newly_joined_rooms, newly_joined_users)` """ - user_id = sync_result_builer.sync_config.user.to_string() + user_id = sync_result_builder.sync_config.user.to_string() now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_result_builer.sync_config, sync_result_builer.now_token + sync_result_builder.sync_config, sync_result_builder.now_token ) - sync_result_builer.now_token = now_token + sync_result_builder.now_token = now_token ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id, @@ -665,38 +665,38 @@ class SyncHandler(object): else: ignored_users = frozenset() - if sync_result_builer.since_token: - res = yield self._get_rooms_changed(sync_result_builer, ignored_users) + if sync_result_builder.since_token: + res = yield self._get_rooms_changed(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res tags_by_room = yield self.store.get_updated_tags( user_id, - sync_result_builer.since_token.account_data_key, + sync_result_builder.since_token.account_data_key, ) else: - res = yield self._get_all_rooms(sync_result_builer, ignored_users) + res = yield self._get_all_rooms(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res tags_by_room = yield self.store.get_tags_for_user(user_id) def handle_room_entries(room_entry): return self._generate_room_entry( - sync_result_builer, + sync_result_builder, ignored_users, room_entry, ephemeral=ephemeral_by_room.get(room_entry.room_id, []), tags=tags_by_room.get(room_entry.room_id), account_data=account_data_by_room.get(room_entry.room_id, {}), - always_include=sync_result_builer.full_state, + always_include=sync_result_builder.full_state, ) yield concurrently_execute(handle_room_entries, room_entries, 10) - sync_result_builer.invited.extend(invited) + sync_result_builder.invited.extend(invited) # Now we want to get any newly joined users newly_joined_users = set() - for joined_sync in sync_result_builer.joined: + for joined_sync in sync_result_builder.joined: it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) for event in it: if event.type == EventTypes.Member: @@ -706,21 +706,21 @@ class SyncHandler(object): defer.returnValue((newly_joined_rooms, newly_joined_users)) @defer.inlineCallbacks - def _get_rooms_changed(self, sync_result_builer, ignored_users): + def _get_rooms_changed(self, sync_result_builder, ignored_users): """Gets the the changes that have happened since the last sync. Args: - sync_result_builer(SyncResultBuilder) + sync_result_builder(SyncResultBuilder) ignored_users(set(str)): Set of users ignored by user. Returns: Deferred(tuple): Returns a tuple of the form: `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)` """ - user_id = sync_result_builer.sync_config.user.to_string() - since_token = sync_result_builer.since_token - now_token = sync_result_builer.now_token - sync_config = sync_result_builer.sync_config + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config assert since_token @@ -843,11 +843,11 @@ class SyncHandler(object): defer.returnValue((room_entries, invited, newly_joined_rooms)) @defer.inlineCallbacks - def _get_all_rooms(self, sync_result_builer, ignored_users): + def _get_all_rooms(self, sync_result_builder, ignored_users): """Returns entries for all rooms for the user. Args: - sync_result_builer(SyncResultBuilder) + sync_result_builder(SyncResultBuilder) ignored_users(set(str)): Set of users ignored by user. Returns: @@ -855,10 +855,10 @@ class SyncHandler(object): `([RoomSyncResultBuilder], [InvitedSyncResult], [])` """ - user_id = sync_result_builer.sync_config.user.to_string() - since_token = sync_result_builer.since_token - now_token = sync_result_builer.now_token - sync_config = sync_result_builer.sync_config + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config membership_list = ( Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN @@ -914,14 +914,14 @@ class SyncHandler(object): defer.returnValue((room_entries, invited, [])) @defer.inlineCallbacks - def _generate_room_entry(self, sync_result_builer, ignored_users, + def _generate_room_entry(self, sync_result_builder, ignored_users, room_builder, ephemeral, tags, account_data, always_include=False): - """Populates the `joined` and `archived` section of `sync_result_builer` + """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. Args: - sync_result_builer(SyncResultBuilder) + sync_result_builder(SyncResultBuilder) ignored_users(set(str)): Set of users ignored by user. room_builder(RoomSyncResultBuilder) ephemeral(list): List of new ephemeral events for room @@ -931,9 +931,9 @@ class SyncHandler(object): always_include(bool): Always include this room in the sync response, even if empty. """ - since_token = sync_result_builer.since_token - now_token = sync_result_builer.now_token - sync_config = sync_result_builer.sync_config + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config room_id = room_builder.room_id events = room_builder.events @@ -941,7 +941,7 @@ class SyncHandler(object): full_state = ( room_builder.full_state or newly_joined - or sync_result_builer.full_state + or sync_result_builder.full_state ) since_token = room_builder.since_token upto_token = room_builder.upto_token @@ -1001,7 +1001,7 @@ class SyncHandler(object): unread_notifications["notification_count"] = notifs["notify_count"] unread_notifications["highlight_count"] = notifs["highlight_count"] - sync_result_builer.joined.append(room_sync) + sync_result_builder.joined.append(room_sync) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( room_id=room_id, @@ -1010,7 +1010,7 @@ class SyncHandler(object): account_data=account_data, ) if room_sync or always_include: - sync_result_builer.archived.append(room_sync) + sync_result_builder.archived.append(room_sync) else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) -- cgit 1.5.1 From b08ad0389e0e412798f32d9f5656db483238173d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 11:04:35 +0100 Subject: Only include non-offline presence in initial sync --- synapse/handlers/sync.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 12a4cc8b57..8143171c11 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -606,13 +606,16 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: presence_key = since_token.presence_key + include_offline = True else: presence_key = None + include_offline = False presence, presence_key = yield presence_source.get_new_events( user=user, from_key=presence_key, is_guest=sync_config.is_guest, + include_offline=include_offline, ) sync_result_builder.now_token = now_token.copy_and_replace( "presence_key", presence_key -- cgit 1.5.1 From 1c5ed2a19ba6da4478c54df072e623598c0ea60d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 11:21:34 +0100 Subject: Only work out newly_joined_users for incremental sync --- synapse/handlers/sync.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 8143171c11..476dcf38e2 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -699,12 +699,15 @@ class SyncHandler(object): # Now we want to get any newly joined users newly_joined_users = set() - for joined_sync in sync_result_builder.joined: - it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) - for event in it: - if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - newly_joined_users.add(event.state_key) + if sync_result_builder.since_token: + for joined_sync in sync_result_builder.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + newly_joined_users.add(event.state_key) defer.returnValue((newly_joined_rooms, newly_joined_users)) -- cgit 1.5.1 From 69003039973169097592359f3f34fc32c5bbaeb0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 11:44:55 +0100 Subject: Don't send down all ephemeral events --- synapse/handlers/sync.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 476dcf38e2..6f7dd45ef3 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -655,7 +655,9 @@ class SyncHandler(object): user_id = sync_result_builder.sync_config.user.to_string() now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_result_builder.sync_config, sync_result_builder.now_token + sync_result_builder.sync_config, + now_token=sync_result_builder.now_token, + since_token=sync_result_builder.since_token, ) sync_result_builder.now_token = now_token -- cgit 1.5.1 From faad233ea61cfff2c377609fa1d3c64d39f8a039 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 May 2016 14:00:43 +0100 Subject: Change short circuit path --- synapse/handlers/sync.py | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6f7dd45ef3..3b89582d79 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -273,23 +273,14 @@ class SyncHandler(object): a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): - filtering_factor = 2 timeline_limit = sync_config.filter_collection.timeline_limit() - load_limit = max(timeline_limit * filtering_factor, 10) - max_repeat = 5 # Only try a few times per room, otherwise - room_key = now_token.room_key - end_key = room_key if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True else: limited = False - if since_token: - if not now_token.is_after(since_token): - limited = False - - if recents is not None: + if recents: recents = sync_config.filter_collection.filter_room_timeline(recents) recents = yield filter_events_for_client( self.store, @@ -299,6 +290,19 @@ class SyncHandler(object): else: recents = [] + if not limited: + defer.returnValue(TimelineBatch( + events=recents, + prev_batch=now_token, + limited=False + )) + + filtering_factor = 2 + load_limit = max(timeline_limit * filtering_factor, 10) + max_repeat = 5 # Only try a few times per room, otherwise + room_key = now_token.room_key + end_key = room_key + since_key = None if since_token and not newly_joined_room: since_key = since_token.room_key @@ -939,18 +943,24 @@ class SyncHandler(object): always_include(bool): Always include this room in the sync response, even if empty. """ - since_token = sync_result_builder.since_token - now_token = sync_result_builder.now_token - sync_config = sync_result_builder.sync_config - - room_id = room_builder.room_id - events = room_builder.events newly_joined = room_builder.newly_joined full_state = ( room_builder.full_state or newly_joined or sync_result_builder.full_state ) + events = room_builder.events + + # We want to shortcut out as early as possible. + if not (always_include or account_data or ephemeral or full_state): + if events == [] and tags is None: + return + + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + room_id = room_builder.room_id since_token = room_builder.since_token upto_token = room_builder.upto_token -- cgit 1.5.1 From cc84f7cb8e7f379ceb23a8bc719465b989199b20 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 27 May 2016 10:35:15 +0100 Subject: Send down correct error response if user not found --- synapse/handlers/auth.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d0d78fc6..26c865e171 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID -from synapse.api.errors import AuthError, LoginError, Codes +from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor from twisted.web.client import PartialDownloadError @@ -563,7 +563,12 @@ class AuthHandler(BaseHandler): except_access_token_ids = [requester.access_token_id] if requester else [] - yield self.store.user_set_password_hash(user_id, password_hash) + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e yield self.store.user_delete_access_tokens( user_id, except_access_token_ids ) -- cgit 1.5.1 From 887c6e6f052e1dc5e61a0b4bade8e7bd3a63e275 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 31 May 2016 11:05:16 +0100 Subject: Split out the room list handler So I can use it from federation bits without pulling in all the handlers. --- synapse/handlers/__init__.py | 3 +-- synapse/rest/client/v1/room.py | 2 +- synapse/server.py | 5 +++++ 3 files changed, 7 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 9442ae6f1d..0ac5d3da3a 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -17,7 +17,7 @@ from synapse.appservice.scheduler import AppServiceScheduler from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomListHandler, RoomContextHandler, + RoomCreationHandler, RoomContextHandler, ) from .room_member import RoomMemberHandler from .message import MessageHandler @@ -50,7 +50,6 @@ class Handlers(object): self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) - self.room_list_handler = RoomListHandler(hs) self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 644aa4e513..2d22bbdaa3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -279,7 +279,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - handler = self.handlers.room_list_handler + handler = self.hs.get_room_list_handler() data = yield handler.get_public_room_list() defer.returnValue((200, data)) diff --git a/synapse/server.py b/synapse/server.py index 01f828819f..bfd5608b7d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -30,6 +30,7 @@ from synapse.handlers import Handlers from synapse.handlers.presence import PresenceHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler +from synapse.handlers.room import RoomListHandler from synapse.state import StateHandler from synapse.storage import DataStore from synapse.util import Clock @@ -84,6 +85,7 @@ class HomeServer(object): 'presence_handler', 'sync_handler', 'typing_handler', + 'room_list_handler', 'notifier', 'distributor', 'client_resource', @@ -179,6 +181,9 @@ class HomeServer(object): def build_sync_handler(self): return SyncHandler(self) + def build_room_list_handler(self): + return RoomListHandler(self) + def build_event_sources(self): return EventSources(self) -- cgit 1.5.1 From c626fc576a87f401c594cbb070d5b0000e45b4e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 31 May 2016 13:53:48 +0100 Subject: Move the AS handler out of the Handlers object. Access it directly from the homeserver itself. It already wasn't inheriting from BaseHandler storing it on the Handlers object was already somewhat dubious. --- synapse/appservice/scheduler.py | 14 +++++++------- synapse/handlers/__init__.py | 11 ----------- synapse/handlers/appservice.py | 15 +++++---------- synapse/handlers/directory.py | 3 ++- synapse/notifier.py | 10 ++++------ synapse/server.py | 15 +++++++++++++++ tests/handlers/test_appservice.py | 6 +++--- 7 files changed, 36 insertions(+), 38 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 47a4e9f864..9afc8fd754 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -56,22 +56,22 @@ import logging logger = logging.getLogger(__name__) -class AppServiceScheduler(object): +class ApplicationServiceScheduler(object): """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this case is a simple array. """ - def __init__(self, clock, store, as_api): - self.clock = clock - self.store = store - self.as_api = as_api + def __init__(self, hs): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.as_api = hs.get_application_service_api() def create_recoverer(service, callback): - return _Recoverer(clock, store, as_api, service, callback) + return _Recoverer(self.clock, self.store, self.as_api, service, callback) self.txn_ctrl = _TransactionController( - clock, store, as_api, create_recoverer + self.clock, self.store, self.as_api, create_recoverer ) self.queuer = _ServiceQueuer(self.txn_ctrl) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 0ac5d3da3a..c0069e23d6 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice.scheduler import AppServiceScheduler -from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( RoomCreationHandler, RoomContextHandler, @@ -26,7 +24,6 @@ from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler -from .appservice import ApplicationServicesHandler from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler @@ -53,14 +50,6 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) - asapi = ApplicationServiceApi(hs) - self.appservice_handler = ApplicationServicesHandler( - hs, asapi, AppServiceScheduler( - clock=hs.get_clock(), - store=hs.get_datastore(), - as_api=asapi - ) - ) self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 75fc74c797..051ccdb380 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.appservice import ApplicationService -from synapse.types import UserID import logging @@ -35,16 +34,13 @@ def log_failure(failure): ) -# NB: Purposefully not inheriting BaseHandler since that contains way too much -# setup code which this handler does not need or use. This makes testing a lot -# easier. class ApplicationServicesHandler(object): - def __init__(self, hs, appservice_api, appservice_scheduler): + def __init__(self, hs): self.store = hs.get_datastore() - self.hs = hs - self.appservice_api = appservice_api - self.scheduler = appservice_scheduler + self.is_mine_id = hs.is_mine_id + self.appservice_api = hs.get_application_service_api() + self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False @defer.inlineCallbacks @@ -169,8 +165,7 @@ class ApplicationServicesHandler(object): @defer.inlineCallbacks def _is_unknown_user(self, user_id): - user = UserID.from_string(user_id) - if not self.hs.is_mine(user): + if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. defer.returnValue(False) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8eeb225811..4bea7f2b19 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -33,6 +33,7 @@ class DirectoryHandler(BaseHandler): super(DirectoryHandler, self).__init__(hs) self.state = hs.get_state_handler() + self.appservice_handler = hs.get_application_service_handler() self.federation = hs.get_replication_layer() self.federation.register_query_handler( @@ -281,7 +282,7 @@ class DirectoryHandler(BaseHandler): ) if not result: # Query AS to see if it exists - as_handler = self.hs.get_handlers().appservice_handler + as_handler = self.appservice_handler result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) diff --git a/synapse/notifier.py b/synapse/notifier.py index 33b79c0ec7..cbec4d30ae 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -140,8 +140,6 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs): - self.hs = hs - self.user_to_user_stream = {} self.room_to_user_streams = {} self.appservice_to_user_streams = {} @@ -151,6 +149,8 @@ class Notifier(object): self.pending_new_room_events = [] self.clock = hs.get_clock() + self.appservice_handler = hs.get_application_service_handler() + self.state_handler = hs.get_state_handler() hs.get_distributor().observe( "user_joined_room", self._user_joined_room @@ -232,9 +232,7 @@ class Notifier(object): def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - self.hs.get_handlers().appservice_handler.notify_interested_services( - event - ) + self.appservice_handler.notify_interested_services(event) app_streams = set() @@ -449,7 +447,7 @@ class Notifier(object): @defer.inlineCallbacks def _is_world_readable(self, room_id): - state = yield self.hs.get_state_handler().get_current_state( + state = yield self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility ) diff --git a/synapse/server.py b/synapse/server.py index bfd5608b7d..7cf22b1eea 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -22,6 +22,8 @@ from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.enterprise import adbapi +from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier @@ -31,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.room import RoomListHandler +from synapse.handlers.appservice import ApplicationServicesHandler from synapse.state import StateHandler from synapse.storage import DataStore from synapse.util import Clock @@ -86,6 +89,9 @@ class HomeServer(object): 'sync_handler', 'typing_handler', 'room_list_handler', + 'application_service_api', + 'application_service_scheduler', + 'application_service_handler', 'notifier', 'distributor', 'client_resource', @@ -184,6 +190,15 @@ class HomeServer(object): def build_room_list_handler(self): return RoomListHandler(self) + def build_application_service_api(self): + return ApplicationServiceApi(self) + + def build_application_service_scheduler(self): + return ApplicationServiceScheduler(self) + + def build_application_service_handler(self): + return ApplicationServicesHandler(self) + def build_event_sources(self): return EventSources(self) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 7ddbbb9b4a..a884c95f8d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -30,9 +30,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore = Mock(return_value=self.mock_store) - self.handler = ApplicationServicesHandler( - hs, self.mock_as_api, self.mock_scheduler - ) + hs.get_application_service_api = Mock(return_value=self.mock_as_api) + hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) + self.handler = ApplicationServicesHandler(hs) @defer.inlineCallbacks def test_notify_interested_services(self): -- cgit 1.5.1 From d240796dedcfae1f6929c1501e7e335df417cfaf Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 31 May 2016 17:20:07 +0100 Subject: Basic, un-cached support for secondary_directory_servers --- synapse/federation/federation_client.py | 21 +++++++++++++++++++++ synapse/federation/transport/client.py | 12 ++++++++++++ synapse/federation/transport/server.py | 2 +- synapse/handlers/room.py | 33 ++++++++++++++++++++++++++++++++- synapse/rest/client/v1/room.py | 3 ++- 5 files changed, 68 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 37ee469fa2..ba8d71c050 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -24,6 +24,7 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent @@ -550,6 +551,26 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") + @defer.inlineCallbacks + def get_public_rooms(self, destinations): + results_by_server = {} + + @defer.inlineCallbacks + def _get_result(s): + if s == self.server_name: + defer.returnValue() + + try: + result = yield self.transport_layer.get_public_rooms(s) + results_by_server[s] = result + except: + logger.exception("Error getting room list from server %r", s) + + + yield concurrently_execute(_get_result, destinations, 3) + + defer.returnValue(results_by_server) + @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): """ diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cd2841c4db..ebb698e278 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -224,6 +224,18 @@ class TransportLayerClient(object): defer.returnValue(response) + @defer.inlineCallbacks + @log_function + def get_public_rooms(self, remote_server): + path = PREFIX + "/publicRooms" + + response = yield self.client.get_json( + destination=remote_server, + path=path, + ) + + defer.returnValue(response) + @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index f23c02efde..da9e7a326d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -527,7 +527,7 @@ class PublicRoomList(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, request): - data = yield self.room_list_handler.get_public_room_list() + data = yield self.room_list_handler.get_local_public_room_list() defer.returnValue((200, data)) # Avoid doing remote HS authorization checks which are done by default by diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d63b3c513..b0aa9fb511 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -345,7 +345,7 @@ class RoomListHandler(BaseHandler): super(RoomListHandler, self).__init__(hs) self.response_cache = ResponseCache() - def get_public_room_list(self): + def get_local_public_room_list(self): result = self.response_cache.get(()) if not result: result = self.response_cache.set((), self._get_public_room_list()) @@ -427,6 +427,37 @@ class RoomListHandler(BaseHandler): # FIXME (erikj): START is no longer a valid value defer.returnValue({"start": "START", "end": "END", "chunk": results}) + @defer.inlineCallbacks + def get_aggregated_public_room_list(self): + """ + Get the public room list from this server and the servers + specified in the secondary_directory_servers config option. + XXX: Pagination... + """ + federated_by_server = yield self.hs.get_replication_layer().get_public_rooms( + self.hs.config.secondary_directory_servers + ) + public_rooms = yield self.get_local_public_room_list() + + # keep track of which room IDs we've seen so we can de-dup + room_ids = set() + + # tag all the ones in our list with our server name. + # Also add the them to the de-deping set + for room in public_rooms['chunk']: + room["server_name"] = self.hs.hostname + room_ids.add(room["room_id"]) + + # Now add the results from federation + for server_name, server_result in federated_by_server.items(): + for room in server_result["chunk"]: + if room["room_id"] not in room_ids: + room["server_name"] = server_name + public_rooms["chunk"].append(room) + room_ids.add(room["room_id"]) + + defer.returnValue(public_rooms) + class RoomContextHandler(BaseHandler): @defer.inlineCallbacks diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2d22bbdaa3..db52a1fc39 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -280,7 +280,8 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): handler = self.hs.get_room_list_handler() - data = yield handler.get_public_room_list() + data = yield handler.get_aggregated_public_room_list() + defer.returnValue((200, data)) -- cgit 1.5.1 From 2a449fec4d03b7739269e832683905a18459e4c3 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 31 May 2016 18:27:23 +0100 Subject: Add cache to remote room lists Poll for updates from remote servers, waiting for the poll if there's no cache entry. --- synapse/handlers/room.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b0aa9fb511..77063b021a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -36,6 +36,8 @@ import string logger = logging.getLogger(__name__) +REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 + id_server_scheme = "https://" @@ -344,6 +346,12 @@ class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) self.response_cache = ResponseCache() + self.remote_list_request_cache = ResponseCache() + self.remote_list_cache = {} + self.fetch_looping_call = hs.get_clock().looping_call( + self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL + ) + self.fetch_all_remote_lists() def get_local_public_room_list(self): result = self.response_cache.get(()) @@ -427,6 +435,14 @@ class RoomListHandler(BaseHandler): # FIXME (erikj): START is no longer a valid value defer.returnValue({"start": "START", "end": "END", "chunk": results}) + @defer.inlineCallbacks + def fetch_all_remote_lists(self): + deferred = self.hs.get_replication_layer().get_public_rooms( + self.hs.config.secondary_directory_servers + ) + self.remote_list_request_cache.set((), deferred) + yield deferred + @defer.inlineCallbacks def get_aggregated_public_room_list(self): """ @@ -434,9 +450,19 @@ class RoomListHandler(BaseHandler): specified in the secondary_directory_servers config option. XXX: Pagination... """ - federated_by_server = yield self.hs.get_replication_layer().get_public_rooms( - self.hs.config.secondary_directory_servers - ) + # We return the results from out cache which is updated by a looping call, + # unless we're missing a cache entry, in which case wait for the result + # of the fetch if there's one in progress. If not, omit that server. + wait = False + for s in self.hs.config.secondary_directory_servers: + if s not in self.remote_list_cache: + logger.warn("No cached room list from %s: waiting for fetch", s) + wait = True + break + + if wait and self.remote_list_request_cache.get(()): + yield self.remote_list_request_cache.get(()) + public_rooms = yield self.get_local_public_room_list() # keep track of which room IDs we've seen so we can de-dup @@ -449,7 +475,7 @@ class RoomListHandler(BaseHandler): room_ids.add(room["room_id"]) # Now add the results from federation - for server_name, server_result in federated_by_server.items(): + for server_name, server_result in self.remote_list_cache.items(): for room in server_result["chunk"]: if room["room_id"] not in room_ids: room["server_name"] = server_name -- cgit 1.5.1 From e0deeff23eec099ad6bcb6e7170f524dc14982e4 Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 1 Jun 2016 17:58:58 +0100 Subject: Fix room list spidering --- synapse/handlers/room.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 77063b021a..9fd34588dd 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -441,7 +441,7 @@ class RoomListHandler(BaseHandler): self.hs.config.secondary_directory_servers ) self.remote_list_request_cache.set((), deferred) - yield deferred + self.remote_list_cache = yield deferred @defer.inlineCallbacks def get_aggregated_public_room_list(self): -- cgit 1.5.1 From a15ad608496fd29fb8bf289152c23adca822beca Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 2 Jun 2016 11:44:15 +0100 Subject: Email unsubscribing that may in theory, work Were it not for that fact that you can't use the base handler in the pusher because it pulls in the world. Comitting while I fix that on a different branch. --- synapse/handlers/auth.py | 5 +++++ synapse/push/emailpusher.py | 2 +- synapse/push/mailer.py | 21 ++++++++++++++++----- 3 files changed, 22 insertions(+), 6 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 26c865e171..200793b5ed 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -529,6 +529,11 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + def validate_short_term_login_token_and_get_user_id(self, login_token): try: macaroon = pymacaroons.Macaroon.deserialize(login_token) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index a72cba8306..46d7c0434b 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -273,5 +273,5 @@ class EmailPusher(object): logger.info("Sending notif email for user %r", self.user_id) yield self.mailer.send_notification_mail( - self.user_id, self.email, push_actions, reason + self.app_id, self.user_id, self.email, push_actions, reason ) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 3ae92d1574..95250bad7d 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -81,6 +81,7 @@ class Mailer(object): def __init__(self, hs): self.hs = hs self.store = self.hs.get_datastore() + self.handlers = self.hs.get_handlers() self.state_handler = self.hs.get_state_handler() loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = self.hs.config.email_app_name @@ -95,7 +96,8 @@ class Mailer(object): ) @defer.inlineCallbacks - def send_notification_mail(self, user_id, email_address, push_actions, reason): + def send_notification_mail(self, app_id, user_id, email_address, + push_actions, reason): raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1] raw_to = email.utils.parseaddr(email_address)[1] @@ -157,7 +159,7 @@ class Mailer(object): template_vars = { "user_display_name": user_display_name, - "unsubscribe_link": self.make_unsubscribe_link(), + "unsubscribe_link": self.make_unsubscribe_link(app_id, email_address), "summary_text": summary_text, "app_name": self.app_name, "rooms": rooms, @@ -423,9 +425,18 @@ class Mailer(object): notif['room_id'], notif['event_id'] ) - def make_unsubscribe_link(self): - # XXX: matrix.to - return "https://vector.im/#/settings" + def make_unsubscribe_link(self, app_id, email_address): + params = { + "access_token": self.handlers.auth.generate_delete_pusher_token(), + "app_id": app_id, + "pushkey": email_address, + } + + # XXX: make r0 once API is stable + return "%s_matrix/client/unstable/pushers/remove?%s" % ( + self.hs.config.public_baseurl, + urllib.urlencode(params), + ) def mxc_to_http_filter(self, value, width, height, resize_method="crop"): if value[0:6] != "mxc://": -- cgit 1.5.1 From 4a10510cd5aff790127a185ecefc83b881a717cc Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 2 Jun 2016 13:31:45 +0100 Subject: Split out the auth handler --- synapse/handlers/__init__.py | 2 -- synapse/handlers/register.py | 2 +- synapse/rest/client/v1/login.py | 11 ++++++----- synapse/rest/client/v2_alpha/account.py | 4 ++-- synapse/rest/client/v2_alpha/auth.py | 2 +- synapse/rest/client/v2_alpha/register.py | 2 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 2 +- synapse/server.py | 5 +++++ tests/rest/client/v2_alpha/test_register.py | 2 +- tests/utils.py | 15 +++++---------- 10 files changed, 23 insertions(+), 24 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index c0069e23d6..d28e07f0d9 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -24,7 +24,6 @@ from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler -from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler from .search import SearchHandler @@ -50,7 +49,6 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) - self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 16f33f8371..bbc07b045e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue((user_id, token)) def auth_handler(self): - return self.hs.get_handlers().auth_handler + return self.hs.get_auth_handler() @defer.inlineCallbacks def guest_access_token_for(self, medium, address, inviter_user_id): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3b5544851b..8df9d10efa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet): self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() + self.auth_handler = self.hs.get_auth_handler() def on_GET(self, request): flows = [] @@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id, self.hs.hostname ).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id=user_id, password=login_submission["password"]) @@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( @@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( @@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if not user_exists: user_id, _ = ( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c88c270537..9a84873a5f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet): super(PasswordRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 78181b7b18..58d3cad6a1 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet): super(AuthRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1ecc02d94d..2088c316d1 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index a158c2209a..8270e8787f 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet): body = parse_json_object_from_request(request) try: old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_handlers().auth_handler + auth_handler = self.hs.get_auth_handler() (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( old_refresh_token, auth_handler.generate_refresh_token) new_access_token = yield auth_handler.issue_access_token(user_id) diff --git a/synapse/server.py b/synapse/server.py index 7cf22b1eea..dd4b81c658 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.room import RoomListHandler +from synapse.handlers.auth import AuthHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.state import StateHandler from synapse.storage import DataStore @@ -89,6 +90,7 @@ class HomeServer(object): 'sync_handler', 'typing_handler', 'room_list_handler', + 'auth_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -190,6 +192,9 @@ class HomeServer(object): def build_room_list_handler(self): return RoomListHandler(self) + def build_auth_handler(self): + return AuthHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index affd42c015..cda0a2b27c 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase): # do the dance to hook it up to the hs global self.handlers = Mock( - auth_handler=self.auth_handler, registration_handler=self.registration_handler, identity_handler=self.identity_handler, login_handler=self.login_handler @@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) + self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.config.enable_registration = True # init the thing we're testing diff --git a/tests/utils.py b/tests/utils.py index 006abedbc1..e19ae581e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): ) # bcrypt is far too slow to be doing in unit tests - def swap_out_hash_for_testing(old_build_handlers): - def build_handlers(): - handlers = old_build_handlers() - auth_handler = handlers.auth_handler - auth_handler.hash = lambda p: hashlib.md5(p).hexdigest() - auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h - return handlers - return build_handlers - - hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest() + hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h fed = kargs.get("resource_for_federation", None) if fed: -- cgit 1.5.1 From 70599ce9252997d32d0bf9f26a4e02c99bbe474d Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 2 Jun 2016 15:20:15 +0100 Subject: Allow external processes to mark a user as syncing. (#812) * Add infrastructure to the presence handler to track sync requests in external processes * Expire stale entries for dead external processes * Add an http endpoint for making users as syncing Add some docstrings and comments. * Fixes --- synapse/handlers/presence.py | 119 +++++++++++++++++++++++++++---- synapse/replication/presence_resource.py | 59 +++++++++++++++ synapse/replication/resource.py | 2 + tests/handlers/test_presence.py | 16 ++--- 4 files changed, 174 insertions(+), 22 deletions(-) create mode 100644 synapse/replication/presence_resource.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 37f57301fb..fc8538b41e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -68,6 +68,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000 # How often to resend presence to remote servers FEDERATION_PING_INTERVAL = 25 * 60 * 1000 +# How long we will wait before assuming that the syncs from an external process +# are dead. +EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 + assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER @@ -158,10 +162,21 @@ class PresenceHandler(object): self.serial_to_user = {} self._next_serial = 1 - # Keeps track of the number of *ongoing* syncs. While this is non zero - # a user will never go offline. + # Keeps track of the number of *ongoing* syncs on this process. While + # this is non zero a user will never go offline. self.user_to_num_current_syncs = {} + # Keeps track of the number of *ongoing* syncs on other processes. + # While any sync is ongoing on another process the user will never + # go offline. + # Each process has a unique identifier and an update frequency. If + # no update is received from that process within the update period then + # we assume that all the sync requests on that process have stopped. + # Stored as a dict from process_id to set of user_id, and a dict of + # process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs = {} + self.external_process_last_updated_ms = {} + # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. @@ -272,13 +287,26 @@ class PresenceHandler(object): # Fetch the list of users that *may* have timed out. Things may have # changed since the timeout was set, so we won't necessarily have to # take any action. - users_to_check = self.wheel_timer.fetch(now) + users_to_check = set(self.wheel_timer.fetch(now)) + + # Check whether the lists of syncing processes from an external + # process have expired. + expired_process_ids = [ + process_id for process_id, last_update + in self.external_process_last_update.items() + if now - last_update > EXTERNAL_PROCESS_EXPIRY + ] + for process_id in expired_process_ids: + users_to_check.update( + self.external_process_to_current_syncs.pop(process_id, ()) + ) + self.external_process_last_update.pop(process_id) states = [ self.user_to_current_state.get( user_id, UserPresenceState.default(user_id) ) - for user_id in set(users_to_check) + for user_id in users_to_check ] timers_fired_counter.inc_by(len(states)) @@ -286,7 +314,7 @@ class PresenceHandler(object): changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - user_to_num_current_syncs=self.user_to_num_current_syncs, + syncing_users=self.get_syncing_users(), now=now, ) @@ -363,6 +391,73 @@ class PresenceHandler(object): defer.returnValue(_user_syncing()) + def get_currently_syncing_users(self): + """Get the set of user ids that are currently syncing on this HS. + Returns: + set(str): A set of user_id strings. + """ + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + syncing_user_ids.update(self.external_process_to_current_syncs.values()) + return syncing_user_ids + + @defer.inlineCallbacks + def update_external_syncs(self, process_id, syncing_user_ids): + """Update the syncing users for an external process + + Args: + process_id(str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. + syncing_user_ids(set(str)): The set of user_ids that are + currently syncing on that server. + """ + + # Grab the previous list of user_ids that were syncing on that process + prev_syncing_user_ids = ( + self.external_process_to_current_syncs.get(process_id, set()) + ) + # Grab the current presence state for both the users that are syncing + # now and the users that were syncing before this update. + prev_states = yield self.current_state_for_users( + syncing_user_ids | prev_syncing_user_ids + ) + updates = [] + time_now_ms = self.clock.time_msec() + + # For each new user that is syncing check if we need to mark them as + # being online. + for new_user_id in syncing_user_ids - prev_syncing_user_ids: + prev_state = prev_states[new_user_id] + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=time_now_ms, + last_user_sync_ts=time_now_ms, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + # For each user that is still syncing or stopped syncing update the + # last sync time so that we will correctly apply the grace period when + # they stop syncing. + for old_user_id in prev_syncing_user_ids: + prev_state = prev_states[old_user_id] + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + yield self._update_states(updates) + + # Update the last updated time for the process. We expire the entries + # if we don't receive an update in the given timeframe. + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + self.external_process_to_current_syncs[process_id] = syncing_user_ids + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. @@ -935,15 +1030,14 @@ class PresenceEventSource(object): return self.get_new_events(user, from_key=None, include_offline=False) -def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): +def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): """Checks the presence of users that have timed out and updates as appropriate. Args: user_states(list): List of UserPresenceState's to check. is_mine_fn (fn): Function that returns if a user_id is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -954,21 +1048,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): for state in user_states: is_mine = is_mine_fn(state.user_id) - new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) + new_state = handle_timeout(state, is_mine, syncing_user_ids, now) if new_state: changes[state.user_id] = new_state return changes.values() -def handle_timeout(state, is_mine, user_to_num_current_syncs, now): +def handle_timeout(state, is_mine, syncing_user_ids, now): """Checks the presence of the user to see if any of the timers have elapsed Args: state (UserPresenceState) is_mine (bool): Whether the user is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -1002,7 +1095,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now): # If there are have been no sync for a while (and none ongoing), # set presence to offline - if not user_to_num_current_syncs.get(user_id, 0): + if user_id not in syncing_user_ids: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( state=PresenceState.OFFLINE, diff --git a/synapse/replication/presence_resource.py b/synapse/replication/presence_resource.py new file mode 100644 index 0000000000..fc18130ab4 --- /dev/null +++ b/synapse/replication/presence_resource.py @@ -0,0 +1,59 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json_bytes, request_handler +from synapse.http.servlet import parse_json_object_from_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + + +class PresenceResource(Resource): + """ + HTTP endpoint for marking users as syncing. + + POST /_synapse/replication/presence HTTP/1.1 + Content-Type: application/json + + { + "process_id": "", + "syncing_users": [""] + } + """ + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.clock = hs.get_clock() + self.presence_handler = hs.get_presence_handler() + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @request_handler() + @defer.inlineCallbacks + def _async_render_POST(self, request): + content = parse_json_object_from_request(request) + + process_id = content["process_id"] + syncing_user_ids = content["syncing_users"] + + yield self.presence_handler.update_external_syncs( + process_id, set(syncing_user_ids) + ) + + respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 847f212a3d..8c2d487ff4 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -16,6 +16,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.http.server import request_handler, finish_request from synapse.replication.pusher_resource import PusherResource +from synapse.replication.presence_resource import PresenceResource from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -115,6 +116,7 @@ class ReplicationResource(Resource): self.clock = hs.get_clock() self.putChild("remove_pushers", PusherResource(hs)) + self.putChild("syncing_users", PresenceResource(hs)) def render_GET(self, request): self._async_render_GET(request) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 87c795fcfa..b531ba8540 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -264,7 +264,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -282,7 +282,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -300,9 +300,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={ - user_id: 1, - }, now=now + state, is_mine=True, syncing_user_ids=set([user_id]), now=now ) self.assertIsNotNone(new_state) @@ -321,7 +319,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -340,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNone(new_state) @@ -358,7 +356,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=False, user_to_num_current_syncs={}, now=now + state, is_mine=False, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -377,7 +375,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) -- cgit 1.5.1 From 661a540dd1de89a3ab3a8f6ca0f780ea7d264176 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 2 Jun 2016 15:20:28 +0100 Subject: Deduplicate presence entries in sync (#818) --- synapse/handlers/sync.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3b89582d79..5307b62b85 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -637,6 +637,9 @@ class SyncHandler(object): ) presence.extend(states) + # Deduplicate the presence entries so that there's at most one per user + presence = {p["content"]["user_id"]: p for p in presence}.values() + presence = sync_config.filter_collection.filter_presence( presence ) -- cgit 1.5.1 From 56d15a05306896555ded3025f8a808bda04872fa Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 2 Jun 2016 16:28:54 +0100 Subject: Store the typing users as user_id strings. (#819) Rather than storing them as UserID objects. --- synapse/handlers/typing.py | 64 ++++++++++++++++++++++++------------------- tests/handlers/test_typing.py | 4 +-- 2 files changed, 38 insertions(+), 30 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d46f05f426..3c54307bed 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user")) +RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) class TypingHandler(object): @@ -38,7 +38,7 @@ class TypingHandler(object): self.store = hs.get_datastore() self.server_name = hs.config.server_name self.auth = hs.get_auth() - self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id self.notifier = hs.get_notifier() self.clock = hs.get_clock() @@ -67,20 +67,23 @@ class TypingHandler(object): @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): - if not self.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has started typing in %s", target_user.to_string(), room_id + "%s has started typing in %s", target_user_id, room_id ) until = self.clock.time_msec() + timeout - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) was_present = member in self._member_typing_until @@ -104,25 +107,28 @@ class TypingHandler(object): yield self._push_update( room_id=room_id, - user=target_user, + user_id=target_user_id, typing=True, ) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): - if not self.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has stopped typing in %s", target_user.to_string(), room_id + "%s has stopped typing in %s", target_user_id, room_id ) - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) if member in self._member_typing_timer: self.clock.cancel_call_later(self._member_typing_timer[member]) @@ -132,8 +138,9 @@ class TypingHandler(object): @defer.inlineCallbacks def user_left_room(self, user, room_id): - if self.is_mine(user): - member = RoomMember(room_id=room_id, user=user) + user_id = user.to_string() + if self.is_mine_id(user_id): + member = RoomMember(room_id=room_id, user=user_id) yield self._stopped_typing(member) @defer.inlineCallbacks @@ -144,7 +151,7 @@ class TypingHandler(object): yield self._push_update( room_id=member.room_id, - user=member.user, + user_id=member.user_id, typing=False, ) @@ -156,7 +163,7 @@ class TypingHandler(object): del self._member_typing_timer[member] @defer.inlineCallbacks - def _push_update(self, room_id, user, typing): + def _push_update(self, room_id, user_id, typing): domains = yield self.store.get_joined_hosts_for_room(room_id) deferreds = [] @@ -164,7 +171,7 @@ class TypingHandler(object): if domain == self.server_name: self._push_update_local( room_id=room_id, - user=user, + user_id=user_id, typing=typing ) else: @@ -173,7 +180,7 @@ class TypingHandler(object): edu_type="m.typing", content={ "room_id": room_id, - "user_id": user.to_string(), + "user_id": user_id, "typing": typing, }, )) @@ -183,23 +190,26 @@ class TypingHandler(object): @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] - user = UserID.from_string(content["user_id"]) + user_id = content["user_id"] + + # Check that the string is a valid user id + UserID.from_string(user_id) domains = yield self.store.get_joined_hosts_for_room(room_id) if self.server_name in domains: self._push_update_local( room_id=room_id, - user=user, + user_id=user_id, typing=content["typing"] ) - def _push_update_local(self, room_id, user, typing): + def _push_update_local(self, room_id, user_id, typing): room_set = self._room_typing.setdefault(room_id, set()) if typing: - room_set.add(user) + room_set.add(user_id) else: - room_set.discard(user) + room_set.discard(user_id) self._latest_room_serial += 1 self._room_serials[room_id] = self._latest_room_serial @@ -215,9 +225,7 @@ class TypingHandler(object): for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps([ - u.to_string() for u in typing - ], ensure_ascii=False) + typing_bytes = json.dumps(list(typing), ensure_ascii=False) rows.append((serial, room_id, typing_bytes)) rows.sort() return rows @@ -239,7 +247,7 @@ class TypingNotificationEventSource(object): "type": "m.typing", "room_id": room_id, "content": { - "user_ids": [u.to_string() for u in typing], + "user_ids": list(typing), }, } diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index abb739ae52..ab9899b7d5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -251,12 +251,12 @@ class TypingNotificationsTestCase(unittest.TestCase): # Gut-wrenching from synapse.handlers.typing import RoomMember - member = RoomMember(self.room_id, self.u_apple) + member = RoomMember(self.room_id, self.u_apple.to_string()) self.handler._member_typing_until[member] = 1002000 self.handler._member_typing_timer[member] = ( self.clock.call_later(1002, lambda: 0) ) - self.handler._room_typing[self.room_id] = set((self.u_apple,)) + self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),)) self.assertEquals(self.event_source.get_current_key(), 0) -- cgit 1.5.1 From 6a0afa582aa5bf816e082af31ac44e2a8fee28c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jun 2016 14:27:07 +0100 Subject: Load push rules in storage layer, so that they get cached --- synapse/handlers/sync.py | 5 ++--- synapse/push/bulk_push_rule_evaluator.py | 28 ----------------------- synapse/push/clientformat.py | 30 ++++++++++++++++++------- synapse/rest/client/v1/push_rule.py | 6 ++--- synapse/storage/push_rule.py | 38 +++++++++++++++++++++++++++++++- 5 files changed, 63 insertions(+), 44 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5307b62b85..be26a491ff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -198,9 +198,8 @@ class SyncHandler(object): @defer.inlineCallbacks def push_rules_for_user(self, user): user_id = user.to_string() - rawrules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - rules = format_push_rules_for_user(user, rawrules, enabled_map) + rules = yield self.store.get_push_rules_for_user(user_id) + rules = format_push_rules_for_user(user, rules) defer.returnValue(rules) @defer.inlineCallbacks diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index af5212a5d1..6e42121b1d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -18,7 +18,6 @@ import ujson as json from twisted.internet import defer -from .baserules import list_with_base_rules from .push_rule_evaluator import PushRuleEvaluatorForEvent from synapse.api.constants import EventTypes, Membership @@ -38,36 +37,9 @@ def decode_rule_json(rule): @defer.inlineCallbacks def _get_rules(room_id, user_ids, store): rules_by_user = yield store.bulk_get_push_rules(user_ids) - rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - rules_by_user = { - uid: list_with_base_rules([ - decode_rule_json(rule_list) - for rule_list in rules_by_user.get(uid, []) - ]) - for uid in user_ids - } - - # We apply the rules-enabled map here: bulk_get_push_rules doesn't - # fetch disabled rules, but this won't account for any server default - # rules the user has disabled, so we need to do this too. - for uid in user_ids: - user_enabled_map = rules_enabled_by_user.get(uid) - if not user_enabled_map: - continue - - for i, rule in enumerate(rules_by_user[uid]): - rule_id = rule['rule_id'] - - if rule_id in user_enabled_map: - if rule.get('enabled', True) != bool(user_enabled_map[rule_id]): - # Rules are cached across users. - rule = dict(rule) - rule['enabled'] = bool(user_enabled_map[rule_id]) - rules_by_user[uid][i] = rule - defer.returnValue(rules_by_user) diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index ae9db9ec2f..b3983f7940 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -23,10 +23,7 @@ import copy import simplejson as json -def format_push_rules_for_user(user, rawrules, enabled_map): - """Converts a list of rawrules and a enabled map into nested dictionaries - to match the Matrix client-server format for push rules""" - +def load_rules_for_user(user, rawrules, enabled_map): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -35,7 +32,26 @@ def format_push_rules_for_user(user, rawrules, enabled_map): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule['rule_id'] + if rule_id in enabled_map: + if rule.get('enabled', True) != bool(enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule['enabled'] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + +def format_push_rules_for_user(user, ruleslist): + """Converts a list of rawrules and a enabled map into nested dictionaries + to match the Matrix client-server format for push rules""" + + # We're going to be mutating this a lot, so do a deep copy + ruleslist = copy.deepcopy(ruleslist) rules = {'global': {}, 'device': {}} @@ -60,9 +76,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map): template_rule = _rule_to_template(r) if template_rule: - if r['rule_id'] in enabled_map: - template_rule['enabled'] = enabled_map[r['rule_id']] - elif 'enabled' in r: + if 'enabled' in r: template_rule['enabled'] = r['enabled'] else: template_rule['enabled'] = True diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 02d837ee6a..6bb4821ec6 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.store.get_push_rules_for_user(user_id) + rules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - - rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) + rules = format_push_rules_for_user(requester.user, rules) path = request.postpath[1:] diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index ebb97c8474..786d6f6d67 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -15,6 +15,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.push.baserules import list_with_base_rules from twisted.internet import defer import logging @@ -23,6 +24,29 @@ import simplejson as json logger = logging.getLogger(__name__) +def _load_rules(rawrules, enabled_map): + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule['rule_id'] + if rule_id in enabled_map: + if rule.get('enabled', True) != bool(enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule['enabled'] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + class PushRuleStore(SQLBaseStore): @cachedInlineCallbacks(lru=True) def get_push_rules_for_user(self, user_id): @@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore): key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) ) - defer.returnValue(rows) + enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + + rules = _load_rules(rows, enabled_map) + + defer.returnValue(rules) @cachedInlineCallbacks(lru=True) def get_push_rules_enabled_for_user(self, user_id): @@ -85,6 +113,14 @@ class PushRuleStore(SQLBaseStore): for row in rows: results.setdefault(row['user_name'], []).append(row) + + enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + + for user_id, rules in results.items(): + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}) + ) + defer.returnValue(results) @cachedList(cached_method_name="get_push_rules_enabled_for_user", -- cgit 1.5.1 From 4c04222fa55d35ad3a75c5538ec477046b6c5b30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Jun 2016 13:02:33 +0100 Subject: Poke notifier on next reactor tick --- synapse/handlers/message.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c41dafdef5..15caf1950a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -26,9 +26,9 @@ from synapse.types import ( UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id ) from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute +from synapse.util.async import concurrently_execute, run_on_reactor from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import preserve_fn from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -908,13 +908,16 @@ class MessageHandler(BaseHandler): "Failed to get destination from event %s", s.event_id ) - with PreserveLoggingContext(): - # Don't block waiting on waking up all the listeners. + @defer.inlineCallbacks + def _notify(): + yield run_on_reactor() self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) + preserve_fn(_notify)() + # If invite, remove room_state from unsigned before sending. event.unsigned.pop("invite_room_state", None) -- cgit 1.5.1 From a7ff5a17702812ae586228396d534a8ed3d88475 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jun 2016 13:40:55 +0100 Subject: Presence metrics. Change def of small delta --- synapse/handlers/presence.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fc8538b41e..eb877763ee 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -50,6 +50,9 @@ timers_fired_counter = metrics.register_counter("timers_fired") federation_presence_counter = metrics.register_counter("federation_presence") bump_active_time_counter = metrics.register_counter("bump_active_time") +full_update_presence_counter = metrics.register_counter("full_update_presence") +partial_update_presence_counter = metrics.register_counter("partial_update_presence") + # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" @@ -974,13 +977,13 @@ class PresenceEventSource(object): user_ids_changed = set() changed = None - if from_key and max_token - from_key < 100: - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list + if from_key: changed = stream_change_cache.get_all_entities_changed(from_key) - # get_all_entities_changed can return None - if changed is not None: + if changed is not None and len(changed) < 100: + # For small deltas, its quicker to get all changes and then + # work out if we share a room or they're in our presence list + partial_update_presence_counter.inc() for other_user_id in changed: if other_user_id in friends: user_ids_changed.add(other_user_id) @@ -992,6 +995,8 @@ class PresenceEventSource(object): else: # Too many possible updates. Find all users we can see and check # if any of them have changed. + full_update_presence_counter.inc() + user_ids_to_check = set() for room_id in room_ids: users = yield self.store.get_users_in_room(room_id) -- cgit 1.5.1 From 4ce84a1acd89a7f61896e92605e5463864848122 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jun 2016 13:49:16 +0100 Subject: Change metric style --- synapse/handlers/presence.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index eb877763ee..0e19f777b8 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -50,8 +50,7 @@ timers_fired_counter = metrics.register_counter("timers_fired") federation_presence_counter = metrics.register_counter("federation_presence") bump_active_time_counter = metrics.register_counter("bump_active_time") -full_update_presence_counter = metrics.register_counter("full_update_presence") -partial_update_presence_counter = metrics.register_counter("partial_update_presence") +get_updates_counter = metrics.register_counter("get_updates", labels=["type"]) # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them @@ -980,10 +979,10 @@ class PresenceEventSource(object): if from_key: changed = stream_change_cache.get_all_entities_changed(from_key) - if changed is not None and len(changed) < 100: + if changed is not None and len(changed) < 500: # For small deltas, its quicker to get all changes and then # work out if we share a room or they're in our presence list - partial_update_presence_counter.inc() + get_updates_counter.inc("stream") for other_user_id in changed: if other_user_id in friends: user_ids_changed.add(other_user_id) @@ -995,7 +994,7 @@ class PresenceEventSource(object): else: # Too many possible updates. Find all users we can see and check # if any of them have changed. - full_update_presence_counter.inc() + get_updates_counter.inc("full") user_ids_to_check = set() for room_id in room_ids: -- cgit 1.5.1 From ab116bdb0c2d8e295b1473af84c453d212dc07ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jun 2016 14:03:42 +0100 Subject: Fix typo --- synapse/handlers/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3c54307bed..861b8f7989 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -140,7 +140,7 @@ class TypingHandler(object): def user_left_room(self, user, room_id): user_id = user.to_string() if self.is_mine_id(user_id): - member = RoomMember(room_id=room_id, user=user_id) + member = RoomMember(room_id=room_id, user_id=user_id) yield self._stopped_typing(member) @defer.inlineCallbacks -- cgit 1.5.1 From 377eb480ca66a376e85cf8927f7f9112ed60e8bc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jun 2016 15:14:21 +0100 Subject: Fire after 30s not 8h --- synapse/handlers/presence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0e19f777b8..2e772da660 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -183,7 +183,7 @@ class PresenceHandler(object): # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. self.clock.call_later( - 30 * 1000, + 30, self.clock.looping_call, self._handle_timeouts, 5000, -- cgit 1.5.1 From 96dc600579cd6ef9937b0e007f51aa4da0fc122d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jun 2016 15:44:41 +0100 Subject: Fix typos --- synapse/handlers/presence.py | 68 +++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 32 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2e772da660..94160a5be7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -283,44 +283,48 @@ class PresenceHandler(object): """Checks the presence of users that have timed out and updates as appropriate. """ + logger.info("Handling presence timeouts") now = self.clock.time_msec() - with Measure(self.clock, "presence_handle_timeouts"): - # Fetch the list of users that *may* have timed out. Things may have - # changed since the timeout was set, so we won't necessarily have to - # take any action. - users_to_check = set(self.wheel_timer.fetch(now)) - - # Check whether the lists of syncing processes from an external - # process have expired. - expired_process_ids = [ - process_id for process_id, last_update - in self.external_process_last_update.items() - if now - last_update > EXTERNAL_PROCESS_EXPIRY - ] - for process_id in expired_process_ids: - users_to_check.update( - self.external_process_to_current_syncs.pop(process_id, ()) - ) - self.external_process_last_update.pop(process_id) + try: + with Measure(self.clock, "presence_handle_timeouts"): + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = set(self.wheel_timer.fetch(now)) + + # Check whether the lists of syncing processes from an external + # process have expired. + expired_process_ids = [ + process_id for process_id, last_update + in self.external_process_last_updated_ms.items() + if now - last_update > EXTERNAL_PROCESS_EXPIRY + ] + for process_id in expired_process_ids: + users_to_check.update( + self.external_process_last_updated_ms.pop(process_id, ()) + ) + self.external_process_last_update.pop(process_id) - states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) - for user_id in users_to_check - ] + states = [ + self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) + for user_id in users_to_check + ] - timers_fired_counter.inc_by(len(states)) + timers_fired_counter.inc_by(len(states)) - changes = handle_timeouts( - states, - is_mine_fn=self.is_mine_id, - syncing_users=self.get_syncing_users(), - now=now, - ) + changes = handle_timeouts( + states, + is_mine_fn=self.is_mine_id, + syncing_user_ids=self.get_currently_syncing_users(), + now=now, + ) - preserve_fn(self._update_states)(changes) + preserve_fn(self._update_states)(changes) + except: + logger.exception("Exception in _handle_timeouts loop") @defer.inlineCallbacks def bump_presence_active_time(self, user): -- cgit 1.5.1 From 216a05b3e39e08b0600a39fc111b4d669d06ff7c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jun 2016 16:00:09 +0100 Subject: .values() returns list of sets --- synapse/handlers/presence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 94160a5be7..6b70fa3817 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -406,7 +406,8 @@ class PresenceHandler(object): user_id for user_id, count in self.user_to_num_current_syncs.items() if count } - syncing_user_ids.update(self.external_process_to_current_syncs.values()) + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) return syncing_user_ids @defer.inlineCallbacks -- cgit 1.5.1 From 0b2158719c43eab87ab7a9448ae1d85008b92b92 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 7 Jun 2016 15:07:11 +0100 Subject: Remove dead code. Loading push rules now happens in the datastore, so we can remove the methods that loaded them outside the datastore. The ``waiting_for_join_list`` in federation handler is populated by anything, so can be removed. The ``_get_members_events_txn`` method isn't called from anywhere so can be removed. --- synapse/handlers/federation.py | 13 ------------- synapse/push/bulk_push_rule_evaluator.py | 8 -------- synapse/push/clientformat.py | 26 -------------------------- synapse/storage/roommember.py | 7 ------- 4 files changed, 54 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 648a505e65..ff83c608e7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -66,10 +66,6 @@ class FederationHandler(BaseHandler): self.hs = hs - self.distributor.observe("user_joined_room", self.user_joined_room) - - self.waiting_for_join_list = {} - self.store = hs.get_datastore() self.replication_layer = hs.get_replication_layer() self.state_handler = hs.get_state_handler() @@ -1091,15 +1087,6 @@ class FederationHandler(BaseHandler): def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) - @log_function - def user_joined_room(self, user, room_id): - waiters = self.waiting_for_join_list.get( - (user.to_string(), room_id), - [] - ) - while waiters: - waiters.pop().callback(None) - @defer.inlineCallbacks @log_function def _handle_new_event(self, origin, event, state=None, auth_events=None, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 6e42121b1d..756e5da513 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -14,7 +14,6 @@ # limitations under the License. import logging -import ujson as json from twisted.internet import defer @@ -27,13 +26,6 @@ from synapse.visibility import filter_events_for_clients logger = logging.getLogger(__name__) -def decode_rule_json(rule): - rule = dict(rule) - rule['conditions'] = json.loads(rule['conditions']) - rule['actions'] = json.loads(rule['actions']) - return rule - - @defer.inlineCallbacks def _get_rules(room_id, user_ids, store): rules_by_user = yield store.bulk_get_push_rules(user_ids) diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index b3983f7940..e0331b2d2d 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -13,37 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.push.baserules import list_with_base_rules - from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) import copy -import simplejson as json - - -def load_rules_for_user(user, rawrules, enabled_map): - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = json.loads(rawrule["conditions"]) - rule["actions"] = json.loads(rawrule["actions"]) - ruleslist.append(rule) - - # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) - - for i, rule in enumerate(rules): - rule_id = rule['rule_id'] - if rule_id in enabled_map: - if rule.get('enabled', True) != bool(enabled_map[rule_id]): - # Rules are cached across users. - rule = dict(rule) - rule['enabled'] = bool(enabled_map[rule_id]) - rules[i] = rule - - return rules def format_push_rules_for_user(user, ruleslist): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 64b4bd371b..8bd693be72 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -243,13 +243,6 @@ class RoomMemberStore(SQLBaseStore): user_ids = yield self.get_users_in_room(room_id) defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids)) - def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): - rows = self._get_members_rows_txn( - txn, - room_id, membership, user_id, - ) - return [r["event_id"] for r in rows] - def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): where_clause = "c.room_id = ?" where_values = [room_id] -- cgit 1.5.1 From 1a815fb04f1d17286be27379dd7463936606bd3a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Jun 2016 11:33:30 +0100 Subject: Don't hit DB for noop replications queries --- synapse/handlers/typing.py | 3 +++ synapse/storage/account_data.py | 3 +++ synapse/storage/presence.py | 3 +++ synapse/storage/push_rule.py | 3 +++ synapse/storage/tags.py | 3 +++ 5 files changed, 15 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 861b8f7989..5589296c09 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -221,6 +221,9 @@ class TypingHandler(object): def get_all_typing_updates(self, last_id, current_id): # TODO: Work out a way to do this without scanning the entire state. + if last_id == current_id: + return [] + rows = [] for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index ec7e8d40d2..3fa226e92d 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -138,6 +138,9 @@ class AccountDataStore(SQLBaseStore): A deferred pair of lists of tuples of stream_id int, user_id string, room_id string, type string, and content string. """ + if last_room_id == current_id and last_global_id == current_id: + return defer.succeed(([], [])) + def get_updated_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type, content" diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 3fab57a7e8..d03f7c541e 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -118,6 +118,9 @@ class PresenceStore(SQLBaseStore): ) def get_all_presence_updates(self, last_id, current_id): + if last_id == current_id: + return defer.succeed([]) + def get_all_presence_updates_txn(txn): sql = ( "SELECT stream_id, user_id, state, last_active_ts," diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 786d6f6d67..8183b7f1b0 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -421,6 +421,9 @@ class PushRuleStore(SQLBaseStore): def get_all_push_rule_updates(self, last_id, current_id, limit): """Get all the push rules changes that have happend on the server""" + if last_id == current_id: + return defer.succeed([]) + def get_all_push_rule_updates_txn(txn): sql = ( "SELECT stream_id, event_stream_ordering, user_id, rule_id," diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 9da23f34cb..5a2c1aa59b 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -68,6 +68,9 @@ class TagsStore(SQLBaseStore): A deferred list of tuples of stream_id int, user_id string, room_id string, tag string and content string. """ + if last_id == current_id: + defer.returnValue([]) + def get_all_updated_tags_txn(txn): sql = ( "SELECT stream_id, user_id, room_id" -- cgit 1.5.1 From 81c07a32fd27260b5112dcc87845a9e87fa5db58 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Jun 2016 15:43:37 +0100 Subject: Pull full state for each room all at once --- synapse/handlers/room.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9fd34588dd..ae44c7a556 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,7 +20,7 @@ from ._base import BaseHandler from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, JoinRules, RoomCreationPreset, + EventTypes, JoinRules, RoomCreationPreset, Membership, ) from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils @@ -367,14 +367,10 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def handle_room(room_id): - # We pull each bit of state out indvidually to avoid pulling the - # full state into memory. Due to how the caching works this should - # be fairly quick, even if not originally in the cache. - def get_state(etype, state_key): - return self.state_handler.get_current_state(room_id, etype, state_key) + current_state = yield self.state_handler.get_current_state(room_id) # Double check that this is actually a public room. - join_rules_event = yield get_state(EventTypes.JoinRules, "") + join_rules_event = current_state.get((EventTypes.JoinRules, "")) if join_rules_event: join_rule = join_rules_event.content.get("join_rule", None) if join_rule and join_rule != JoinRules.PUBLIC: @@ -382,47 +378,51 @@ class RoomListHandler(BaseHandler): result = {"room_id": room_id} - joined_users = yield self.store.get_users_in_room(room_id) - if len(joined_users) == 0: + num_joined_users = len([ + 1 for _, event in current_state.items() + if event.type == EventTypes.Member + and event.membership == Membership.JOIN + ]) + if num_joined_users == 0: return - result["num_joined_members"] = len(joined_users) + result["num_joined_members"] = num_joined_users aliases = yield self.store.get_aliases_for_room(room_id) if aliases: result["aliases"] = aliases - name_event = yield get_state(EventTypes.Name, "") + name_event = yield current_state.get((EventTypes.Name, "")) if name_event: name = name_event.content.get("name", None) if name: result["name"] = name - topic_event = yield get_state(EventTypes.Topic, "") + topic_event = current_state.get((EventTypes.Topic, "")) if topic_event: topic = topic_event.content.get("topic", None) if topic: result["topic"] = topic - canonical_event = yield get_state(EventTypes.CanonicalAlias, "") + canonical_event = current_state.get((EventTypes.CanonicalAlias, "")) if canonical_event: canonical_alias = canonical_event.content.get("alias", None) if canonical_alias: result["canonical_alias"] = canonical_alias - visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "") + visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, "")) visibility = None if visibility_event: visibility = visibility_event.content.get("history_visibility", None) result["world_readable"] = visibility == "world_readable" - guest_event = yield get_state(EventTypes.GuestAccess, "") + guest_event = current_state.get((EventTypes.GuestAccess, "")) guest = None if guest_event: guest = guest_event.content.get("guest_access", None) result["guest_can_join"] = guest == "can_join" - avatar_event = yield get_state("m.room.avatar", "") + avatar_event = current_state.get(("m.room.avatar", "")) if avatar_event: avatar_url = avatar_event.content.get("url", None) if avatar_url: -- cgit 1.5.1 From 6e7dc7c7dde377794c23d5db6f25ffacfb08e82a Mon Sep 17 00:00:00 2001 From: Negar Fazeli Date: Wed, 8 Jun 2016 19:16:46 +0200 Subject: Fix a bug caused by a change in auth_handler function Fix the relevant unit test cases --- synapse/handlers/register.py | 4 ++-- tests/handlers/test_register.py | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index bbc07b045e..e0aaefe7be 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -388,8 +388,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - auth_handler = self.hs.get_handlers().auth_handler - token = auth_handler.generate_short_term_login_token(user_id, duration_seconds) + token = self.auth_handler().generate_short_term_login_token( + user_id, duration_seconds) if need_register: yield self.store.register( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 9d5c653b45..69a5e5b1d4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -41,14 +41,15 @@ class RegistrationTestCase(unittest.TestCase): handlers=None, http_client=None, expire_access_token=True) + self.auth_handler = Mock( + generate_short_term_login_token=Mock(return_value='secret')) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ "generate_short_term_login_token", ]) - - self.hs.get_handlers().auth_handler = self.mock_handler + self.hs.get_auth_handler = Mock(return_value=self.auth_handler) @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): @@ -56,8 +57,6 @@ class RegistrationTestCase(unittest.TestCase): local_part = "someone" display_name = "someone" user_id = "@someone:test" - mock_token = self.mock_handler.generate_short_term_login_token - mock_token.return_value = 'secret' result_user_id, result_token = yield self.handler.get_or_create_user( local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) @@ -75,8 +74,6 @@ class RegistrationTestCase(unittest.TestCase): local_part = "frank" display_name = "Frank" user_id = "@frank:test" - mock_token = self.mock_handler.generate_short_term_login_token - mock_token.return_value = 'secret' result_user_id, result_token = yield self.handler.get_or_create_user( local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) -- cgit 1.5.1 From b31c49d6760b4cdeefc8e0b43d6639be4576e249 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 15 Jun 2016 10:58:07 +0100 Subject: Correctly mark backfilled events as backfilled --- synapse/handlers/federation.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ff83c608e7..c2df43e2f6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -345,19 +345,21 @@ class FederationHandler(BaseHandler): ) missing_auth = required_auth - set(auth_events) - results = yield defer.gatherResults( - [ - self.replication_layer.get_pdu( - [dest], - event_id, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + if missing_auth: + logger.info("Missing auth for backfill: %r", missing_auth) + results = yield defer.gatherResults( + [ + self.replication_layer.get_pdu( + [dest], + event_id, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth + ], + consumeErrors=True + ).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results}) ev_infos = [] for a in auth_events.values(): @@ -399,7 +401,7 @@ class FederationHandler(BaseHandler): # previous to work out the state. # TODO: We can probably do something more clever here. yield self._handle_new_event( - dest, event + dest, event, backfilled=True, ) defer.returnValue(events) -- cgit 1.5.1 From ed5f43a55accc8502a60b721871b208db704de3e Mon Sep 17 00:00:00 2001 From: Salvatore LaMendola Date: Thu, 16 Jun 2016 00:43:42 -0400 Subject: Fix TypeError in call to bcrypt.hashpw - At the very least, this TypeError caused logins to fail on my own running instance of Synapse, and the simple (explicit) UTF-8 conversion resolved login errors for me. Signed-off-by: Salvatore LaMendola --- synapse/handlers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 200793b5ed..b38f81e999 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -626,6 +626,6 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash (bool). """ if stored_hash: - return bcrypt.hashpw(password, stored_hash) == stored_hash + return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash else: return False -- cgit 1.5.1 From 2884712ca733f45d32468ecf2ede7a1518e85be4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Jun 2016 14:47:33 +0100 Subject: Only re-sign our own events --- synapse/federation/federation_server.py | 15 +++++++++------ synapse/handlers/federation.py | 15 +++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index fe92457ba1..2a589524a4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -193,13 +193,16 @@ class FederationServer(FederationBase): ) for event in auth_chain: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) else: raise NotImplementedError("Specify an event") diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c2df43e2f6..6c0bc7eafa 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1018,13 +1018,16 @@ class FederationHandler(BaseHandler): res = results.values() for event in res: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) defer.returnValue(res) else: -- cgit 1.5.1 From 9f1800fba852314332d7e682484e456d28838619 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 17 Jun 2016 19:14:16 +0100 Subject: Remove registered_users from the distributor. The only place that was observed was to set the profile. I've made it so that the profile is set within store.register in the same transaction that creates the user. This required some slight changes to the registration code for upgrading guest users, since it previously relied on the distributor swallowing errors if the profile already existed. --- synapse/handlers/profile.py | 7 ------- synapse/handlers/register.py | 23 ++++++++++------------- synapse/storage/profile.py | 6 ------ synapse/storage/registration.py | 17 ++++++++++++++--- synapse/util/distributor.py | 4 ---- 5 files changed, 24 insertions(+), 33 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index e37409170d..711a6a567f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -36,13 +36,6 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) - distributor = hs.get_distributor() - - distributor.observe("registered_user", self.registered_user) - - def registered_user(self, user): - return self.store.create_profile(user.localpart) - @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index e0aaefe7be..4fb12915dc 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -23,7 +23,6 @@ from synapse.api.errors import ( from ._base import BaseHandler from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient -from synapse.util.distributor import registered_user import logging import urllib @@ -37,8 +36,6 @@ class RegistrationHandler(BaseHandler): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() - self.distributor = hs.get_distributor() - self.distributor.declare("registered_user") self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -140,9 +137,10 @@ class RegistrationHandler(BaseHandler): password_hash=password_hash, was_guest=was_guest, make_guest=make_guest, + create_profile_with_localpart=( + None if was_guest else user.localpart + ), ) - - yield registered_user(self.distributor, user) else: # autogen a sequential user ID attempts = 0 @@ -160,7 +158,8 @@ class RegistrationHandler(BaseHandler): user_id=user_id, token=token, password_hash=password_hash, - make_guest=make_guest + make_guest=make_guest, + create_profile_with_localpart=user.localpart, ) except SynapseError: # if user id is taken, just generate another @@ -168,7 +167,6 @@ class RegistrationHandler(BaseHandler): user_id = None token = None attempts += 1 - yield registered_user(self.distributor, user) # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding @@ -201,8 +199,8 @@ class RegistrationHandler(BaseHandler): token=token, password_hash="", appservice_id=service_id, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) defer.returnValue((user_id, token)) @defer.inlineCallbacks @@ -248,9 +246,9 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=None + password_hash=None, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) except Exception as e: yield self.store.add_access_token_to_user(user_id, token) # Ignore Registration errors @@ -395,10 +393,9 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=None + password_hash=None, + create_profile_with_localpart=user.localpart, ) - - yield registered_user(self.distributor, user) else: yield self.store.user_delete_access_tokens(user_id=user_id) yield self.store.add_access_token_to_user(user_id=user_id, token=token) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 26a40905ae..c3c3f9ffd8 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -17,12 +17,6 @@ from ._base import SQLBaseStore class ProfileStore(SQLBaseStore): - def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", - ) def get_profile_displayname(self, user_localpart): return self._simple_select_one_onecol( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bda84a744a..3de9e0f709 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -76,7 +76,8 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False, appservice_id=None): + was_guest=False, make_guest=False, appservice_id=None, + create_profile_with_localpart=None): """Attempts to register an account. Args: @@ -88,6 +89,8 @@ class RegistrationStore(SQLBaseStore): make_guest (boolean): True if the the new user should be guest, false to add a regular user account. appservice_id (str): The ID of the appservice registering the user. + create_profile_with_localpart (str): Optionally create a profile for + the given localpart. Raises: StoreError if the user_id could not be registered. """ @@ -99,7 +102,8 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, ) self.get_user_by_id.invalidate((user_id,)) self.is_guest.invalidate((user_id,)) @@ -112,7 +116,8 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, ): now = int(self.clock.time()) @@ -157,6 +162,12 @@ class RegistrationStore(SQLBaseStore): (next_id, user_id, token,) ) + if create_profile_with_localpart: + txn.execute( + "INSERT INTO profiles(user_id) VALUES (?)", + (create_profile_with_localpart,) + ) + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index d7cccc06b1..e68f94ce77 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -27,10 +27,6 @@ import logging logger = logging.getLogger(__name__) -def registered_user(distributor, user): - return distributor.fire("registered_user", user) - - def user_left_room(distributor, user, room_id): return preserve_context_over_fn( distributor.fire, -- cgit 1.5.1 From 0c13d45522c5c8c0b68322498a220969eb894ad5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 17 Jun 2016 19:18:53 +0100 Subject: Add a comment on why we don't create a profile for upgrading users --- synapse/handlers/register.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4fb12915dc..0b7517221d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -138,6 +138,7 @@ class RegistrationHandler(BaseHandler): was_guest=was_guest, make_guest=make_guest, create_profile_with_localpart=( + # If the user was a guest then they already have a profile None if was_guest else user.localpart ), ) -- cgit 1.5.1 From 0a32208e5dde4980a5962f17e9b27f2e28e1f3f1 Mon Sep 17 00:00:00 2001 From: Martin Weinelt Date: Mon, 6 Jun 2016 02:05:57 +0200 Subject: Rework ldap integration with ldap3 Use the pure-python ldap3 library, which eliminates the need for a system dependency. Offer both a `search` and `simple_bind` mode, for more sophisticated ldap scenarios. - `search` tries to find a matching DN within the `user_base` while employing the `user_filter`, then tries the bind when a single matching DN was found. - `simple_bind` tries the bind against a specific DN by combining the localpart and `user_base` Offer support for STARTTLS on a plain connection. The configuration was changed to reflect these new possibilities. Signed-off-by: Martin Weinelt --- synapse/config/ldap.py | 102 +++++++++++++++------ synapse/handlers/auth.py | 203 ++++++++++++++++++++++++++++++++++------- synapse/python_dependencies.py | 3 + tests/utils.py | 1 + 4 files changed, 249 insertions(+), 60 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py index 9c14593a99..d83c2230be 100644 --- a/synapse/config/ldap.py +++ b/synapse/config/ldap.py @@ -13,40 +13,88 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError + + +MISSING_LDAP3 = ( + "Missing ldap3 library. This is required for LDAP Authentication." +) + + +class LDAPMode(object): + SIMPLE = "simple", + SEARCH = "search", + + LIST = (SIMPLE, SEARCH) class LDAPConfig(Config): def read_config(self, config): - ldap_config = config.get("ldap_config", None) - if ldap_config: - self.ldap_enabled = ldap_config.get("enabled", False) - self.ldap_server = ldap_config["server"] - self.ldap_port = ldap_config["port"] - self.ldap_tls = ldap_config.get("tls", False) - self.ldap_search_base = ldap_config["search_base"] - self.ldap_search_property = ldap_config["search_property"] - self.ldap_email_property = ldap_config["email_property"] - self.ldap_full_name_property = ldap_config["full_name_property"] - else: - self.ldap_enabled = False - self.ldap_server = None - self.ldap_port = None - self.ldap_tls = False - self.ldap_search_base = None - self.ldap_search_property = None - self.ldap_email_property = None - self.ldap_full_name_property = None + ldap_config = config.get("ldap_config", {}) + + self.ldap_enabled = ldap_config.get("enabled", False) + + if self.ldap_enabled: + # verify dependencies are available + try: + import ldap3 + ldap3 # to stop unused lint + except ImportError: + raise ConfigError(MISSING_LDAP3) + + self.ldap_mode = LDAPMode.SIMPLE + + # verify config sanity + self.require_keys(ldap_config, [ + "uri", + "base", + "attributes", + ]) + + self.ldap_uri = ldap_config["uri"] + self.ldap_start_tls = ldap_config.get("start_tls", False) + self.ldap_base = ldap_config["base"] + self.ldap_attributes = ldap_config["attributes"] + + if "bind_dn" in ldap_config: + self.ldap_mode = LDAPMode.SEARCH + self.require_keys(ldap_config, [ + "bind_dn", + "bind_password", + ]) + + self.ldap_bind_dn = ldap_config["bind_dn"] + self.ldap_bind_password = ldap_config["bind_password"] + self.ldap_filter = ldap_config.get("filter", None) + + # verify attribute lookup + self.require_keys(ldap_config['attributes'], [ + "uid", + "name", + "mail", + ]) + + def require_keys(self, config, required): + missing = [key for key in required if key not in config] + if missing: + raise ConfigError( + "LDAP enabled but missing required config values: {}".format( + ", ".join(missing) + ) + ) def default_config(self, **kwargs): return """\ # ldap_config: # enabled: true - # server: "ldap://localhost" - # port: 389 - # tls: false - # search_base: "ou=Users,dc=example,dc=com" - # search_property: "cn" - # email_property: "email" - # full_name_property: "givenName" + # uri: "ldap://ldap.example.com:389" + # start_tls: true + # base: "ou=users,dc=example,dc=com" + # attributes: + # uid: "cn" + # mail: "email" + # name: "givenName" + # #bind_dn: + # #bind_password: + # #filter: "(objectClass=posixAccount)" """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b38f81e999..968095c141 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -20,6 +20,7 @@ from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor +from synapse.config.ldap import LDAPMode from twisted.web.client import PartialDownloadError @@ -28,6 +29,12 @@ import bcrypt import pymacaroons import simplejson +try: + import ldap3 +except ImportError: + ldap3 = None + pass + import synapse.util.stringutils as stringutils @@ -50,17 +57,20 @@ class AuthHandler(BaseHandler): self.INVALID_TOKEN_HTTP_STATUS = 401 self.ldap_enabled = hs.config.ldap_enabled - self.ldap_server = hs.config.ldap_server - self.ldap_port = hs.config.ldap_port - self.ldap_tls = hs.config.ldap_tls - self.ldap_search_base = hs.config.ldap_search_base - self.ldap_search_property = hs.config.ldap_search_property - self.ldap_email_property = hs.config.ldap_email_property - self.ldap_full_name_property = hs.config.ldap_full_name_property - - if self.ldap_enabled is True: - import ldap - logger.info("Import ldap version: %s", ldap.__version__) + if self.ldap_enabled: + if not ldap3: + raise RuntimeError( + 'Missing ldap3 library. This is required for LDAP Authentication.' + ) + self.ldap_mode = hs.config.ldap_mode + self.ldap_uri = hs.config.ldap_uri + self.ldap_start_tls = hs.config.ldap_start_tls + self.ldap_base = hs.config.ldap_base + self.ldap_filter = hs.config.ldap_filter + self.ldap_attributes = hs.config.ldap_attributes + if self.ldap_mode == LDAPMode.SEARCH: + self.ldap_bind_dn = hs.config.ldap_bind_dn + self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? @@ -452,40 +462,167 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): - if not self.ldap_enabled: - logger.debug("LDAP not configured") + """ Attempt to authenticate a user against an LDAP Server + and register an account if none exists. + + Returns: + True if authentication against LDAP was successful + """ + + if not ldap3 or not self.ldap_enabled: defer.returnValue(False) - import ldap + if self.ldap_mode not in LDAPMode.LIST: + raise RuntimeError( + 'Invalid ldap mode specified: {mode}'.format( + mode=self.ldap_mode + ) + ) - logger.info("Authenticating %s with LDAP" % user_id) try: - ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) - logger.debug("Connecting LDAP server at %s" % ldap_url) - l = ldap.initialize(ldap_url) - if self.ldap_tls: - logger.debug("Initiating TLS") - self._connection.start_tls_s() + server = ldap3.Server(self.ldap_uri) + logger.debug( + "Attempting ldap connection with %s", + self.ldap_uri + ) - local_name = UserID.from_string(user_id).localpart + localpart = UserID.from_string(user_id).localpart + if self.ldap_mode == LDAPMode.SIMPLE: + # bind with the the local users ldap credentials + bind_dn = "{prop}={value},{base}".format( + prop=self.ldap_attributes['uid'], + value=localpart, + base=self.ldap_base + ) + conn = ldap3.Connection(server, bind_dn, password) + logger.debug( + "Established ldap connection in simple mode: %s", + conn + ) - dn = "%s=%s, %s" % ( - self.ldap_search_property, - local_name, - self.ldap_search_base) - logger.debug("DN for LDAP authentication: %s" % dn) + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in simple mode through StartTLS: %s", + conn + ) + + conn.bind() + + elif self.ldap_mode == LDAPMode.SEARCH: + # connect with preconfigured credentials and search for local user + conn = ldap3.Connection( + server, + self.ldap_bind_dn, + self.ldap_bind_password + ) + logger.debug( + "Established ldap connection in search mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in search mode through StartTLS: %s", + conn + ) - l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + conn.bind() + # find matching dn + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + if self.ldap_filter: + query = "(&{query}{filter})".format( + query=query, + filter=self.ldap_filter + ) + logger.debug("ldap search filter: %s", query) + result = conn.search(self.ldap_base, query) + + if result and len(conn.response) == 1: + # found exactly one result + user_dn = conn.response[0]['dn'] + logger.debug('ldap search found dn: %s', user_dn) + + # unbind and reconnect, rebind with found dn + conn.unbind() + conn = ldap3.Connection( + server, + user_dn, + password, + auto_bind=True + ) + else: + # found 0 or > 1 results, abort! + logger.warn( + "ldap search returned unexpected (%d!=1) amount of results", + len(conn.response) + ) + defer.returnValue(False) + + logger.info( + "User authenticated against ldap server: %s", + conn + ) + + # check for existing account, if none exists, create one if not (yield self.does_user_exist(user_id)): - handler = self.hs.get_handlers().registration_handler - user_id, access_token = ( - yield handler.register(localpart=local_name) + # query user metadata for account creation + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + + if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: + query = "(&{filter}{user_filter})".format( + filter=query, + user_filter=self.ldap_filter + ) + logger.debug("ldap registration filter: %s", query) + + result = conn.search( + search_base=self.ldap_base, + search_filter=query, + attributes=[ + self.ldap_attributes['name'], + self.ldap_attributes['mail'] + ] ) + if len(conn.response) == 1: + attrs = conn.response[0]['attributes'] + mail = attrs[self.ldap_attributes['mail']][0] + name = attrs[self.ldap_attributes['name']][0] + + # create account + registration_handler = self.hs.get_handlers().registration_handler + user_id, access_token = ( + yield registration_handler.register(localpart=localpart) + ) + + # TODO: bind email, set displayname with data from ldap directory + + logger.info( + "ldap registration successful: %d: %s (%s, %)", + user_id, + localpart, + name, + mail + ) + else: + logger.warn( + "ldap registration failed: unexpected (%d!=1) amount of results", + len(result) + ) + defer.returnValue(False) + defer.returnValue(True) - except ldap.LDAPError, e: - logger.warn("LDAP error: %s", e) + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during ldap authentication: %s", e) defer.returnValue(False) @defer.inlineCallbacks diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index e0a7a19777..e024cec0a2 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -48,6 +48,9 @@ CONDITIONAL_REQUIREMENTS = { "Jinja2>=2.8": ["Jinja2>=2.8"], "bleach>=1.4.2": ["bleach>=1.4.2"], }, + "ldap": { + "ldap3>=1.0": ["ldap3>=1.0"], + }, } diff --git a/tests/utils.py b/tests/utils.py index 6e41ae1ff6..ed547bc39b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} + config.ldap_enabled = False if "clock" not in kargs: kargs["clock"] = MockClock() -- cgit 1.5.1 From be8be535f73e51a29cfa30f1eac266a7a08b695b Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 30 Jun 2016 17:51:28 +0100 Subject: requestToken update Don't send requestToken request to untrusted ID servers Also correct the THREEPID_IN_USE error to add the M_ prefix. This is a backwards incomaptible change, but the only thing using this is the angular client which is now unmaintained, so it's probably better to just do this now. --- synapse/api/errors.py | 3 ++- synapse/handlers/identity.py | 41 +++++++++++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 15 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b106fbed6d..b219b46a4b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -42,8 +42,9 @@ class Codes(object): TOO_LARGE = "M_TOO_LARGE" EXCLUSIVE = "M_EXCLUSIVE" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" - THREEPID_IN_USE = "THREEPID_IN_USE" + THREEPID_IN_USE = "M_THREEPID_IN_USE" INVALID_USERNAME = "M_INVALID_USERNAME" + SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" class CodeMessageException(RuntimeError): diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 656ce124f9..559e5d5a71 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,7 +21,7 @@ from synapse.api.errors import ( ) from ._base import BaseHandler from synapse.util.async import run_on_reactor -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes import json import logging @@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler): hs.config.use_insecure_ssl_client_just_for_testing_do_not_use ) + def _should_trust_id_server(self, id_server): + if id_server not in self.trusted_id_servers: + if self.trust_any_id_server_just_for_testing_do_not_use: + logger.warn( + "Trusting untrustworthy ID server %r even though it isn't" + " in the trusted id list for testing because" + " 'use_insecure_ssl_client_just_for_testing_do_not_use'" + " is set in the config", + id_server, + ) + else: + return False + return True + @defer.inlineCallbacks def threepid_from_creds(self, creds): yield run_on_reactor() @@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler): else: raise SynapseError(400, "No client_secret in creds") - if id_server not in self.trusted_id_servers: - if self.trust_any_id_server_just_for_testing_do_not_use: - logger.warn( - "Trusting untrustworthy ID server %r even though it isn't" - " in the trusted id list for testing because" - " 'use_insecure_ssl_client_just_for_testing_do_not_use'" - " is set in the config", - id_server, - ) - else: - logger.warn('%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server) - defer.returnValue(None) + if not self._should_trust_id_server(id_server): + logger.warn( + '%s is not a trusted ID server: rejecting 3pid ' + + 'credentials', id_server + ) + defer.returnValue(None) data = {} try: @@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler): def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): yield run_on_reactor() + if not self._should_trust_id_server(id_server): + raise SynapseError( + 400, "Untrusted ID server '%s'" % id_server, + Codes.SERVER_NOT_TRUSTED + ) + params = { 'email': email, 'client_secret': client_secret, -- cgit 1.5.1 From fc8007dbec40212ae85285aea600111ce2d06912 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Sun, 3 Jul 2016 15:08:15 +0900 Subject: Optionally include password hash in createUser endpoint Signed-off-by: Kent Shikama --- synapse/handlers/register.py | 4 ++-- synapse/rest/client/v1/register.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0b7517221d..e255f2da81 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -358,7 +358,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds): + def get_or_create_user(self, localpart, displayname, duration_seconds, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -394,7 +394,7 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=None, + password_hash=password_hash, create_profile_with_localpart=user.localpart, ) else: diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index e3f4fbb0bb..ef56d1e90f 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -410,12 +410,14 @@ class CreateUserRestServlet(ClientV1RestServlet): raise SynapseError(400, "Failed to parse 'duration_seconds'") if duration_seconds > self.direct_user_creation_max_duration: duration_seconds = self.direct_user_creation_max_duration + password_hash = user_json["password_hash"].encode("utf-8") if user_json["password_hash"] else None handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds + duration_seconds=duration_seconds, + password_hash=password_hash ) defer.returnValue({ -- cgit 1.5.1 From bb069079bbd0ce761403416ed4f77051352ed347 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Mon, 4 Jul 2016 22:07:11 +0900 Subject: Fix style violations Signed-off-by: Kent Shikama --- synapse/handlers/register.py | 3 ++- synapse/rest/client/v1/register.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index e255f2da81..88c82ba7d0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -358,7 +358,8 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds, password_hash=None): + def get_or_create_user(self, localpart, displayname, duration_seconds, + password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index a923d5a198..d791d5e07e 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -410,7 +410,8 @@ class CreateUserRestServlet(ClientV1RestServlet): raise SynapseError(400, "Failed to parse 'duration_seconds'") if duration_seconds > self.direct_user_creation_max_duration: duration_seconds = self.direct_user_creation_max_duration - password_hash = user_json["password_hash"].encode("utf-8") if user_json.get("password_hash") else None + password_hash = user_json["password_hash"].encode("utf-8") \ + if user_json.get("password_hash") else None handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( -- cgit 1.5.1 From 8bdaf5f7afaee98a8cf25d2fb170fe4b2aa97f3d Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Tue, 5 Jul 2016 02:13:52 +0900 Subject: Add pepper to password hashing Signed-off-by: Kent Shikama --- synapse/config/password.py | 6 +++++- synapse/handlers/auth.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/config/password.py b/synapse/config/password.py index dec801ef41..ea822f2bb5 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -23,10 +23,14 @@ class PasswordConfig(Config): def read_config(self, config): password_config = config.get("password_config", {}) self.password_enabled = password_config.get("enabled", True) + self.pepper = password_config.get("pepper", "") def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable password for login. password_config: enabled: true - """ + # Uncomment for extra security for your passwords. + # DO NOT CHANGE THIS AFTER INITIAL SETUP! + #pepper: "HR32t0xZcQnzn3O0ZkEVuetdFvH1W6TeEPw6JjH0Cl+qflVOseGyFJlJR7ACLnywjN9" + """ \ No newline at end of file diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 968095c141..fd5fadf73d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -750,7 +750,7 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) + return bcrypt.hashpw(password + self.hs.config.password_config.pepper, bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -763,6 +763,7 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash (bool). """ if stored_hash: - return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash + return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + stored_hash.encode('utf-8')) == stored_hash else: return False -- cgit 1.5.1 From 2d21d43c34751cffb5f324bd58ceff060f65f679 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 10:28:51 +0100 Subject: Add purge_history API --- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 13 +++++++++++++ synapse/rest/client/v1/admin.py | 18 ++++++++++++++++++ synapse/storage/events.py | 6 ++++++ 4 files changed, 38 insertions(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6c0bc7eafa..351b218247 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1413,7 +1413,7 @@ class FederationHandler(BaseHandler): local_view = dict(auth_events) remote_view = dict(auth_events) remote_view.update({ - (d.type, d.state_key): d for d in different_events + (d.type, d.state_key): d for d in different_events if d }) new_state, prev_state = self.state_handler.resolve_events( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15caf1950a..878809d50d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,6 +50,19 @@ class MessageHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() + @defer.inlineCallbacks + def purge_history(self, room_id, event_id): + event = yield self.store.get_event(event_id) + + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") + + depth = event.depth + + # TODO: Lock. + + yield self.store.delete_old_state(room_id, depth) + @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, as_client_event=True): diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index e54c472e08..71537a7d0b 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -77,6 +77,24 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class PurgeHistoryRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history/(?P[^/]*)/(?P[^/]*)" + ) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + yield self.handlers.message_handler.purge_history(room_id, event_id) + + defer.returnValue((200, {})) + + class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P[^/]*)") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 98c917ce15..c3b498bb3d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1281,6 +1281,12 @@ class EventsStore(SQLBaseStore): ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) + def delete_old_state(self, room_id, topological_ordering): + return self.runInteraction( + "delete_old_state", + self._delete_old_state_txn, room_id, topological_ordering + ) + def _delete_old_state_txn(self, txn, room_id, topological_ordering): """Deletes old room state """ -- cgit 1.5.1 From 1ee258430724618c7014bb176186c23b0b5b06f0 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Tue, 5 Jul 2016 19:01:00 +0900 Subject: Fix pep8 --- synapse/config/password.py | 2 +- synapse/handlers/auth.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/config/password.py b/synapse/config/password.py index 7c5cb5f0e1..058a3a5346 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -34,4 +34,4 @@ class PasswordConfig(Config): # Change to a secret random string. # DO NOT CHANGE THIS AFTER INITIAL SETUP! #pepper: "HR32t0xZcQnzn3O0ZkEVuetdFvH1W6TeEPw6JjH0Cl+qflVOseGyFJlJR7ACLnywjN9" - """ \ No newline at end of file + """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fd5fadf73d..be46681c64 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -750,7 +750,8 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password + self.hs.config.password_config.pepper, bcrypt.gensalt(self.bcrypt_rounds)) + return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. -- cgit 1.5.1 From 14362bf3590eb95a50201a84c8e16d5626b86249 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Tue, 5 Jul 2016 19:12:53 +0900 Subject: Fix password config --- synapse/config/password.py | 2 +- synapse/handlers/auth.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/config/password.py b/synapse/config/password.py index 058a3a5346..00b1ea3df9 100644 --- a/synapse/config/password.py +++ b/synapse/config/password.py @@ -23,7 +23,7 @@ class PasswordConfig(Config): def read_config(self, config): password_config = config.get("password_config", {}) self.password_enabled = password_config.get("enabled", True) - self.pepper = password_config.get("pepper", "") + self.password_pepper = password_config.get("pepper", "") def default_config(self, config_dir_path, server_name, **kwargs): return """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index be46681c64..e259213a36 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -750,7 +750,7 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + return bcrypt.hashpw(password + self.hs.config.password_pepper, bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): @@ -764,7 +764,7 @@ class AuthHandler(BaseHandler): Whether self.hash(password) == stored_hash (bool). """ if stored_hash: - return bcrypt.hashpw(password + self.hs.config.password_config.pepper, + return bcrypt.hashpw(password + self.hs.config.password_pepper, stored_hash.encode('utf-8')) == stored_hash else: return False -- cgit 1.5.1 From 8f8798bc0d572af103274fc07d3adac67ce7f51a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 15:30:25 +0100 Subject: Add ReadWriteLock for pagination and history prune --- synapse/handlers/message.py | 70 +++++++++++++++++++++++---------------------- synapse/storage/stream.py | 4 +-- 2 files changed, 38 insertions(+), 36 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 878809d50d..ad2753c1b5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -26,7 +26,7 @@ from synapse.types import ( UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id ) from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute, run_on_reactor +from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import preserve_fn from synapse.visibility import filter_events_for_client @@ -50,6 +50,8 @@ class MessageHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() + self.pagination_lock = ReadWriteLock() + @defer.inlineCallbacks def purge_history(self, room_id, event_id): event = yield self.store.get_event(event_id) @@ -59,9 +61,8 @@ class MessageHandler(BaseHandler): depth = event.depth - # TODO: Lock. - - yield self.store.delete_old_state(room_id, depth) + with (yield self.pagination_lock.write(room_id)): + yield self.store.delete_old_state(room_id, depth) @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, @@ -98,42 +99,43 @@ class MessageHandler(BaseHandler): source_config = pagin_config.get_source_config("room") - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) + with (yield self.pagination_lock.read(room_id)): + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id + ) - if source_config.direction == 'b': - # if we're going backwards, we might need to backfill. This - # requires that we have a topo token. - if room_token.topological: - max_topo = room_token.topological - else: - max_topo = yield self.store.get_max_topological_token_for_stream_and_room( - room_id, room_token.stream - ) + if source_config.direction == 'b': + # if we're going backwards, we might need to backfill. This + # requires that we have a topo token. + if room_token.topological: + max_topo = room_token.topological + else: + max_topo = yield self.store.get_max_topological_token( + room_id, room_token.stream + ) + + if membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + leave_token = yield self.store.get_topological_token_for_event( + member_event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < max_topo: + source_config.from_key = str(leave_token) - if membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room, to save the effort of loading from the - # database. - leave_token = yield self.store.get_topological_token_for_event( - member_event_id + yield self.hs.get_handlers().federation_handler.maybe_backfill( + room_id, max_topo ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < max_topo: - source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, max_topo + events, next_key = yield data_source.get_pagination_rows( + requester.user, source_config, room_id ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id - ) - - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace( + "room_key", next_key + ) if not events: defer.returnValue({ diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index b9ad965fd6..3dda2dab55 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -487,13 +487,13 @@ class StreamStore(SQLBaseStore): row["topological_ordering"], row["stream_ordering"],) ) - def get_max_topological_token_for_stream_and_room(self, room_id, stream_key): + def get_max_topological_token(self, room_id, stream_key): sql = ( "SELECT max(topological_ordering) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) return self._execute( - "get_max_topological_token_for_stream_and_room", None, + "get_max_topological_token", None, sql, room_id, stream_key, ).addCallback( lambda r: r[0][0] if r else 0 -- cgit 1.5.1 From 651faee698d5ff4806d1e0e7f5cd4c438bf434f1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 17:30:22 +0100 Subject: Add an admin option to shared secret registration --- scripts/register_new_matrix_user | 19 ++++++++++-- synapse/handlers/register.py | 4 ++- synapse/rest/client/v1/register.py | 1 + synapse/storage/registration.py | 61 ++++++++++++++++++++++++-------------- 4 files changed, 58 insertions(+), 27 deletions(-) (limited to 'synapse/handlers') diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 6d055fd012..987bf32d1c 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -42,6 +42,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F "password": password, "mac": mac, "type": "org.matrix.login.shared_secret", + "admin": admin, } server_location = server_location.rstrip("/") @@ -73,7 +74,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F sys.exit(1) -def register_new_user(user, password, server_location, shared_secret): +def register_new_user(user, password, server_location, shared_secret, admin): if not user: try: default_user = getpass.getuser() @@ -104,7 +105,14 @@ def register_new_user(user, password, server_location, shared_secret): print "Passwords do not match" sys.exit(1) - request_registration(user, password, server_location, shared_secret) + if not admin: + admin = raw_input("Make admin [no]: ") + if admin in ("y", "yes", "true"): + admin = True + else: + admin = False + + request_registration(user, password, server_location, shared_secret, bool(admin)) if __name__ == "__main__": @@ -124,6 +132,11 @@ if __name__ == "__main__": default=None, help="New password for user. Will prompt if omitted.", ) + parser.add_argument( + "-a", "--admin", + action="store_true", + help="Register new user as an admin. Will prompt if omitted.", + ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( @@ -156,4 +169,4 @@ if __name__ == "__main__": else: secret = args.shared_secret - register_new_user(args.user, args.password, args.server_url, secret) + register_new_user(args.user, args.password, args.server_url, secret, args.admin) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 88c82ba7d0..8c3381df8a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -90,7 +90,8 @@ class RegistrationHandler(BaseHandler): password=None, generate_token=True, guest_access_token=None, - make_guest=False + make_guest=False, + admin=False, ): """Registers a new client on the server. @@ -141,6 +142,7 @@ class RegistrationHandler(BaseHandler): # If the user was a guest then they already have a profile None if was_guest else user.localpart ), + admin=admin, ) else: # autogen a sequential user ID diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 0eb7490e5d..25d63a0b0b 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -345,6 +345,7 @@ class RegisterRestServlet(ClientV1RestServlet): user_id, token = yield handler.register( localpart=user, password=password, + admin=bool(admin), ) self._remove_session(session) defer.returnValue({ diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5c75dbab51..4999175ddb 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -77,7 +77,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def register(self, user_id, token, password_hash, was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_localpart=None): + create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: @@ -104,6 +104,7 @@ class RegistrationStore(SQLBaseStore): make_guest, appservice_id, create_profile_with_localpart, + admin ) self.get_user_by_id.invalidate((user_id,)) self.is_guest.invalidate((user_id,)) @@ -118,6 +119,7 @@ class RegistrationStore(SQLBaseStore): make_guest, appservice_id, create_profile_with_localpart, + admin, ): now = int(self.clock.time()) @@ -125,29 +127,42 @@ class RegistrationStore(SQLBaseStore): try: if was_guest: - txn.execute("UPDATE users SET" - " password_hash = ?," - " upgrade_ts = ?," - " is_guest = ?" - " WHERE name = ?", - [password_hash, now, 1 if make_guest else 0, user_id]) + txn.execute( + "UPDATE users SET" + " password_hash = ?," + " upgrade_ts = ?," + " is_guest = ?," + " admin = ?" + " WHERE name = ?", + (password_hash, now, 1 if make_guest else 0, admin, user_id,) + ) + self._simple_update_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + }, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": admin, + } + ) else: - txn.execute("INSERT INTO users " - "(" - " name," - " password_hash," - " creation_ts," - " is_guest," - " appservice_id" - ") " - "VALUES (?,?,?,?,?)", - [ - user_id, - password_hash, - now, - 1 if make_guest else 0, - appservice_id, - ]) + self._simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": admin, + } + ) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE -- cgit 1.5.1 From 0136a522b18a734db69171d60566f501c0ced663 Mon Sep 17 00:00:00 2001 From: Negar Fazeli Date: Fri, 8 Jul 2016 16:53:18 +0200 Subject: Bug fix: expire invalid access tokens --- synapse/api/auth.py | 3 +++ synapse/handlers/auth.py | 5 +++-- synapse/handlers/register.py | 6 +++--- synapse/rest/client/v1/register.py | 2 +- tests/api/test_auth.py | 31 ++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 4 ++-- 6 files changed, 42 insertions(+), 9 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a4d658a9d0..521a52e001 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -629,7 +629,10 @@ class Auth(object): except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. + if self.hs.config.expire_access_token: + raise ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e259213a36..5a0ed9d6b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -637,12 +637,13 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token(self, user_id, extra_caveats=None, + duration_in_ms=(60 * 60 * 1000)): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8c3381df8a..6b33b27149 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds, + def get_or_create_user(self, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_short_term_login_token( - user_id, duration_seconds) + token = self.auth_handler().generate_access_token( + user_id, None, duration_in_ms) if need_register: yield self.store.register( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ce7099b18f..8e1f1b7845 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet): user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds, + duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ad269af0ec..960c23d631 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) - macaroon.add_first_party_caveat("time < 1") # ms + macaroon.add_first_party_caveat("time < -2000") # ms self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True @@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase): yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_with_valid_duration(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("time < 900000000") # ms + + self.hs.clock.now = 5000 # seconds + self.hs.config.expire_access_token = True + + user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 69a5e5b1d4..a7de3c7c17 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase): http_client=None, expire_access_token=True) self.auth_handler = Mock( - generate_short_term_login_token=Mock(return_value='secret')) + generate_access_token=Mock(return_value='secret')) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ - "generate_short_term_login_token", + "generate_access_token", ]) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) -- cgit 1.5.1 From a98d2152049b0a61426ed3d8b6ac872a9ca3f535 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jul 2016 15:59:25 +0100 Subject: Add filter param to /messages API --- synapse/handlers/message.py | 16 ++++++++++++---- synapse/rest/client/v1/room.py | 11 ++++++++++- tests/storage/event_injector.py | 1 + tests/storage/test_events.py | 12 ++++++------ 4 files changed, 29 insertions(+), 11 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad2753c1b5..dc76d34a52 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -66,7 +66,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, event_filter=None): """Get messages in a room. Args: @@ -75,11 +75,11 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None Returns: dict: Pagination API results """ user_id = requester.user.to_string() - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: room_token = pagin_config.from_token.room_key @@ -129,8 +129,13 @@ class MessageHandler(BaseHandler): room_id, max_topo ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, ) next_token = pagin_config.from_token.copy_and_replace( @@ -144,6 +149,9 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) + if event_filter: + events = event_filter.filter(events) + events = yield filter_events_for_client( self.store, user_id, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86fbe2747d..866a1e9120 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event from synapse.http.servlet import parse_json_object_from_request import logging import urllib +import ujson as json logger = logging.getLogger(__name__) @@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args + filter_bytes = request.args.get("filter", None) + if filter_bytes: + filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + event_filter=event_filter, ) defer.returnValue((200, msgs)) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py index f22ba8db89..38556da9a7 100644 --- a/tests/storage/event_injector.py +++ b/tests/storage/event_injector.py @@ -30,6 +30,7 @@ class EventInjector: def create_room(self, room): builder = self.event_builder_factory.new({ "type": EventTypes.Create, + "sender": "", "room_id": room.to_string(), "content": {}, }) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 18a6cff0c7..3762b38e37 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_count_daily_messages(self): - self.db_pool.runQuery("DELETE FROM stats_reporting") + yield self.db_pool.runQuery("DELETE FROM stats_reporting") self.hs.clock.now = 100 @@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase): # it isn't old enough. count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(1, self.hs.clock.now) + yield self._assert_stats_reporting(1, self.hs.clock.now) # Already reported yesterday, two new events from today. yield self.event_injector.inject_message(room, user, "Yeah they are!") @@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += 60 * 60 * 24 count = yield self.store.count_daily_messages() self.assertEqual(2, count) # 2 since yesterday - self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever # Last reported too recently. yield self.event_injector.inject_message(room, user, "Who could disagree?") self.hs.clock.now += 60 * 60 * 22 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(4, self.hs.clock.now) + yield self._assert_stats_reporting(4, self.hs.clock.now) # Last reported too long ago yield self.event_injector.inject_message(room, user, "No one.") self.hs.clock.now += 60 * 60 * 26 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(5, self.hs.clock.now) + yield self._assert_stats_reporting(5, self.hs.clock.now) # And now let's actually report something yield self.event_injector.inject_message(room, user, "Indeed.") @@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += (60 * 60 * 24) + 50 count = yield self.store.count_daily_messages() self.assertEqual(3, count) - self._assert_stats_reporting(8, self.hs.clock.now) + yield self._assert_stats_reporting(8, self.hs.clock.now) @defer.inlineCallbacks def _get_last_stream_token(self): -- cgit 1.5.1 From ebdafd8114d1aed631a3497ad142f79efa9face7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jul 2016 16:49:37 +0100 Subject: Check sender signed event --- synapse/api/auth.py | 10 ++++++++-- synapse/handlers/federation.py | 4 ++-- synapse/state.py | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e05defd7d8..e2f40ee65a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -63,7 +63,7 @@ class Auth(object): "user_id = ", ]) - def check(self, event, auth_events): + def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. Args: @@ -79,6 +79,13 @@ class Auth(object): if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) + + sender_domain = get_domain_from_id(event.sender) + + # Check the sender's domain has signed the event + if do_sig_check and not event.signatures.get(sender_domain): + raise AuthError(403, "Event not signed by sending server") + if auth_events is None: # Oh, we don't know what the state of the room was, so we # are trusting that this is allowed (at least for now) @@ -87,7 +94,6 @@ class Auth(object): if event.type == EventTypes.Create: room_id_domain = get_domain_from_id(event.room_id) - sender_domain = get_domain_from_id(event.sender) if room_id_domain != sender_domain: raise AuthError( 403, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 351b218247..4e8ffa8f7b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -688,7 +688,7 @@ class FederationHandler(BaseHandler): logger.warn("Failed to create join %r because %s", event, e) raise e - self.auth.check(event, auth_events=context.current_state) + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) defer.returnValue(event) @@ -918,7 +918,7 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, auth_events=context.current_state) + self.auth.check(event, auth_events=context.current_state, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) raise e diff --git a/synapse/state.py b/synapse/state.py index d0f76dc4f5..d7d08570c9 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -379,7 +379,7 @@ class StateHandler(object): try: # FIXME: hs.get_auth() is bad style, but we need to do it to # get around circular deps. - self.hs.get_auth().check(event, auth_events) + self.hs.get_auth().check(event, auth_events, do_sig_check=False) prev_event = event except AuthError: return prev_event @@ -391,7 +391,7 @@ class StateHandler(object): try: # FIXME: hs.get_auth() is bad style, but we need to do it to # get around circular deps. - self.hs.get_auth().check(event, auth_events) + self.hs.get_auth().check(event, auth_events, do_sig_check=False) return event except AuthError: pass -- cgit 1.5.1 From 9e1b43bcbf46c38510cd8348b7df3eb5f6374e81 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 15 Jul 2016 09:29:54 +0100 Subject: Comment --- synapse/handlers/federation.py | 4 ++++ synapse/state.py | 2 ++ 2 files changed, 6 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4e8ffa8f7b..7622962d46 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -688,6 +688,8 @@ class FederationHandler(BaseHandler): logger.warn("Failed to create join %r because %s", event, e) raise e + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_join_request` self.auth.check(event, auth_events=context.current_state, do_sig_check=False) defer.returnValue(event) @@ -918,6 +920,8 @@ class FederationHandler(BaseHandler): ) try: + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_leave_request` self.auth.check(event, auth_events=context.current_state, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) diff --git a/synapse/state.py b/synapse/state.py index d7d08570c9..ef1bc470be 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -379,6 +379,7 @@ class StateHandler(object): try: # FIXME: hs.get_auth() is bad style, but we need to do it to # get around circular deps. + # The signatures have already been checked at this point self.hs.get_auth().check(event, auth_events, do_sig_check=False) prev_event = event except AuthError: @@ -391,6 +392,7 @@ class StateHandler(object): try: # FIXME: hs.get_auth() is bad style, but we need to do it to # get around circular deps. + # The signatures have already been checked at this point self.hs.get_auth().check(event, auth_events, do_sig_check=False) return event except AuthError: -- cgit 1.5.1 From 6344db659f0d4c57551f1da6456dcaa724d5beb2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 18 Jul 2016 09:47:33 +0100 Subject: Fix a doc-comment The `store` in a handler is a generic DataStore, not just an events.StateStore. --- synapse/handlers/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index c904c6c500..d00685c389 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -31,7 +31,7 @@ class BaseHandler(object): Common base class for the event handlers. Attributes: - store (synapse.storage.events.StateStore): + store (synapse.storage.DataStore): state_handler (synapse.state.StateHandler): """ -- cgit 1.5.1 From dcfd71aa4c4a1d3d71356fd2f5d854fb1db8fafa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 12:34:23 +0100 Subject: Refactor login flow Make sure that we have the canonical user_id *before* calling get_login_tuple_for_user_id. Replace login_with_password with a method which just validates the password, and have the caller call get_login_tuple_for_user_id. This brings the password flow into line with the other flows, and will give us a place to register the device_id if necessary. --- synapse/handlers/auth.py | 106 ++++++++++++++++++++++------------------ synapse/rest/client/v1/login.py | 41 +++++++++------- 2 files changed, 82 insertions(+), 65 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5a0ed9d6b9..983994fa95 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -230,7 +230,6 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) @@ -240,11 +239,7 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue(user_id) + return self._check_password(user_id, password) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -348,67 +343,66 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def login_with_password(self, user_id, password): + def validate_password_login(self, user_id, password): """ Authenticates the user with their username and password. Used only by the v1 login API. Args: - user_id (str): User ID + user_id (str): complete @user:id password (str): Password Returns: - A tuple of: - The user's ID. - The access token for the user's session. - The refresh token for the user's session. + defer.Deferred: (str) canonical user id Raises: - StoreError if there was a problem storing the token. + StoreError if there was a problem accessing the database LoginError if there was an authentication problem. """ - - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + return self._check_password(user_id, password) @defer.inlineCallbacks def get_login_tuple_for_user_id(self, user_id): """ Gets login tuple for the user with the given user ID. + + Creates a new access/refresh token for the user. + The user is assumed to have been authenticated by some other - machanism (e.g. CAS) + machanism (e.g. CAS), and the user_id converted to the canonical case. Args: - user_id (str): User ID + user_id (str): canonical User ID Returns: A tuple of: - The user's ID. The access token for the user's session. The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) - logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks - def does_user_exist(self, user_id): + def check_user_exists(self, user_id): + """ + Checks to see if a user with the given id exists. Will check case + insensitively, but return None if there are multiple inexact matches. + + Args: + (str) user_id: complete @user:id + + Returns: + defer.Deferred: (str) canonical_user_id, or None if zero or + multiple matches + """ try: - yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(True) + res = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(res[0]) except LoginError: - defer.returnValue(False) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): @@ -438,27 +432,45 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_password(self, user_id, password): - """ + """Authenticate a user against the LDAP and local databases. + + user_id is checked case insensitively against the local database, but + will throw if there are multiple inexact matches. + + Args: + user_id (str): complete @user:id Returns: - True if the user_id successfully authenticated + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect """ valid_ldap = yield self._check_ldap_password(user_id, password) if valid_ldap: - defer.returnValue(True) + defer.returnValue(user_id) - valid_local_password = yield self._check_local_password(user_id, password) - if valid_local_password: - defer.returnValue(True) - - defer.returnValue(False) + result = yield self._check_local_password(user_id, password) + defer.returnValue(result) @defer.inlineCallbacks def _check_local_password(self, user_id, password): - try: - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(self.validate_hash(password, password_hash)) - except LoginError: - defer.returnValue(False) + """Authenticate a user against the local password database. + + user_id is checked case insensitively, but will throw if there are + multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect + """ + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + result = self.validate_hash(password, password_hash) + if not result: + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): @@ -570,7 +582,7 @@ class AuthHandler(BaseHandler): ) # check for existing account, if none exists, create one - if not (yield self.does_user_exist(user_id)): + if not (yield self.check_user_exists(user_id)): # query user metadata for account creation query = "({prop}={value})".format( prop=self.ldap_attributes['uid'], diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8df9d10efa..a1f2ba8773 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet): ).to_string() auth_handler = self.auth_handler - user_id, access_token, refresh_token = yield auth_handler.login_with_password( + user_id = yield auth_handler.validate_password_login( user_id=user_id, - password=login_submission["password"]) - + password=login_submission["password"], + ) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - user_id, access_token, refresh_token = ( + access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id(user_id) ) result = { @@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id + ) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(registered_user_id) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if not user_exists: - user_id, _ = ( + registered_user_id = yield auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(user_id) + login_token = auth_handler.generate_short_term_login_token(registered_user_id) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) -- cgit 1.5.1 From f863a52ceacf69ab19b073383be80603a2f51c0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 13:19:07 +0100 Subject: Add device_id support to /login Add a 'devices' table to the storage, as well as a 'device_id' column to refresh_tokens. Allow the client to pass a device_id, and initial_device_display_name, to /login. If login is successful, then register the device in the devices table if it wasn't known already. If no device_id was supplied, make one up. Associate the device_id with the access token and refresh token, so that we can get at it again later. Ensure that the device_id is copied from the refresh token to the access_token when the token is refreshed. --- synapse/handlers/auth.py | 19 +++--- synapse/handlers/device.py | 71 ++++++++++++++++++++ synapse/rest/client/v1/login.py | 39 ++++++++++- synapse/rest/client/v2_alpha/tokenrefresh.py | 10 ++- synapse/server.py | 5 ++ synapse/storage/__init__.py | 3 + synapse/storage/devices.py | 77 ++++++++++++++++++++++ synapse/storage/registration.py | 28 +++++--- synapse/storage/schema/delta/33/devices.sql | 21 ++++++ .../schema/delta/33/refreshtoken_device.sql | 16 +++++ tests/handlers/test_device.py | 75 +++++++++++++++++++++ tests/storage/test_registration.py | 21 ++++-- 12 files changed, 354 insertions(+), 31 deletions(-) create mode 100644 synapse/handlers/device.py create mode 100644 synapse/storage/devices.py create mode 100644 synapse/storage/schema/delta/33/devices.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device.sql create mode 100644 tests/handlers/test_device.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 983994fa95..ce9bc18849 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -361,7 +361,7 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id, device_id=None): """ Gets login tuple for the user with the given user ID. @@ -372,6 +372,7 @@ class AuthHandler(BaseHandler): Args: user_id (str): canonical User ID + device_id (str): the device ID to associate with the access token Returns: A tuple of: The access token for the user's session. @@ -380,9 +381,9 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + refresh_token = yield self.issue_refresh_token(user_id, device_id) defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks @@ -638,15 +639,17 @@ class AuthHandler(BaseHandler): defer.returnValue(False) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) @defer.inlineCallbacks - def issue_refresh_token(self, user_id): + def issue_refresh_token(self, user_id, device_id=None): refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) + yield self.store.add_refresh_token_to_user(user_id, refresh_token, + device_id) defer.returnValue(refresh_token) def generate_access_token(self, user_id, extra_caveats=None, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..8d7d9874f8 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import StoreError +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string_with_symbols(16) + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except StoreError: + attempts += 1 + + raise StoreError(500, "Couldn't generate a device ID.") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a1f2ba8773..e8b791519c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id=user_id, password=login_submission["password"], ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(registered_user_id) + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id + ) ) result = { "user_id": registered_user_id, @@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 8270e8787f..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet): try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_auth_handler() - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/server.py b/synapse/server.py index d49a1a8a96..e8b166990d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,6 +25,7 @@ from twisted.enterprise import adbapi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication +from synapse.handlers.device import DeviceHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -92,6 +93,7 @@ class HomeServer(object): 'typing_handler', 'room_list_handler', 'auth_handler', + 'device_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -197,6 +199,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_device_handler(self): + return DeviceHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1c93e18f9d..73fb334dd6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from twisted.internet import defer + +from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) @@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore, OpenIdStore, ClientIpStore, + DeviceStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py new file mode 100644 index 0000000000..9065e96d28 --- /dev/null +++ b/synapse/storage/devices.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class DeviceStore(SQLBaseStore): + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name, + ignore_if_known=True): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device + ignore_if_known (bool): ignore integrity errors which mean the + device is already known + Returns: + defer.Deferred + Raises: + StoreError: if ignore_if_known is False and the device was already + known + """ + try: + yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=ignore_if_known, + ) + except Exception as e: + logger.error("store_device with device_id=%s failed: %s", + device_id, e) + raise StoreError(500, "Problem storing device.") + + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a namedtuple containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d957a629dc..26ef1cfd8a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore): self.clock = hs.get_clock() @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): + def add_refresh_token_to_user(self, user_id, token, device_id=None): """Adds a refresh token for the given user. Args: user_id (str): The user ID. token (str): The new refresh token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_refresh_token_to_user", ) @@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore): ) def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. + """Exchange a refresh token for a new one. Doing so invalidates the old refresh token - refresh tokens are single use. Args: - token (str): The refresh token of a user. + refresh_token (str): The refresh token of a user. token_generator (fn: str -> str): Function which, when given a user ID, returns a unique refresh token for that user. This function must never return the same value twice. Returns: - tuple of (user_id, refresh_token) + tuple of (user_id, new_refresh_token, device_id) Raises: StoreError if no user was found with that refresh token. """ @@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore): ) def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" txn.execute(sql, (old_token,)) rows = self.cursor_to_dict(txn) if not rows: raise StoreError(403, "Did not recognize refresh token") user_id = rows[0]["user_id"] + device_id = rows[0]["device_id"] # TODO(danielwh): Maybe perform a validation on the macaroon that # macaroon.user_id == user_id. @@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore): sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" txn.execute(sql, (new_token, old_token,)) - return user_id, new_token + return user_id, new_token, device_id @defer.inlineCallbacks def is_server_admin(self, user): @@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql new file mode 100644 index 0000000000..b21da00dde --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE refresh_tokens ADD COLUMN device_id BIGINT; diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py new file mode 100644 index 0000000000..cc6512ccc7 --- /dev/null +++ b/tests/handlers/test_device.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.handlers.device import DeviceHandler +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DeviceHandlers(object): + def __init__(self, hs): + self.device_handler = DeviceHandler(hs) + + +class DeviceTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = handlers = DeviceHandlers(self.hs) + self.handler = handlers.device_handler + + @defer.inlineCallbacks + def test_device_is_created_if_doesnt_exist(self): + res = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_is_preserved_if_exists(self): + res1 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res1, "fco") + + res2 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="new display name" + ) + self.assertEqual(res2, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_id_is_made_up_if_unspecified(self): + device_id = yield self.handler.check_device_registered( + user_id="theresa", + device_id=None, + initial_device_display_name="display" + ) + + dev = yield self.handler.store.get_device("theresa", device_id) + self.assertEqual(dev["display_name"], "display") diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b8384c98d8..b03ca303a2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "BcDeFgHiJkLmNoPqRsTuVwXyZa" ] self.pwhash = "{xx1}123456789" + self.device_id = "akgjhdjklgshg" @defer.inlineCallbacks def test_register(self): @@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { "name": self.user_id, + "device_id": self.device_id, }, result ) @@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_exchange_refresh_token_valid(self): uid = stringutils.random_string(32) + device_id = stringutils.random_string(16) generator = TokenGenerator() last_token = generator.generate(uid) self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, last_token,)) + "INSERT INTO refresh_tokens(user_id, token, device_id) " + "VALUES(?,?,?)", + (uid, last_token, device_id)) - (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( - last_token, generator.generate) + (found_user_id, refresh_token, device_id) = \ + yield self.store.exchange_refresh_token(last_token, + generator.generate) self.assertEqual(uid, found_user_id) rows = yield self.db_pool.runQuery( - "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) - self.assertEqual([(refresh_token,)], rows) + "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", + (uid, )) + self.assertEqual([(refresh_token, device_id)], rows) # We issued token 1, then exchanged it for token 2 expected_refresh_token = u"%s-%d" % (uid, 2,) self.assertEqual(expected_refresh_token, refresh_token) -- cgit 1.5.1 From 7e554aac86144ebde529aae259cd0895d4078f23 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 19 Jul 2016 10:18:40 +0100 Subject: Update docstring on Handlers. To indicate it is deprecated. --- synapse/handlers/__init__.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index d28e07f0d9..c512077cb5 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -31,10 +31,22 @@ from .search import SearchHandler class Handlers(object): - """ A collection of all the event handlers. + """ + Deprecated. + + At some point most of the classes whose name ended "Handler" were + accessed through this class. + + However this makes it painful to unit test the handlers and to run cut + down versions of synapse that only use specific handlers because using a + single handler required creating all of the handlers. So some of the + handlers have been lifted out of the Handlers object and are now accessed + directly through the homeserver object itself. + + Any new handlers should follow the new pattern of being accessed through + the homeserver object and should not be added to the Handlers object. - There's no need to lazily create these; we'll just make them all eagerly - at construction time. + The remaining handlers should be moved out of the handlers object. """ def __init__(self, hs): -- cgit 1.5.1 From c41d52a04221d478220ede7ab389299918f113ca Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 19 Jul 2016 10:28:27 +0100 Subject: Summary line --- synapse/handlers/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index c512077cb5..1a50a2ec98 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -31,8 +31,7 @@ from .search import SearchHandler class Handlers(object): - """ - Deprecated. + """ Deprecated. A collection of handlers. At some point most of the classes whose name ended "Handler" were accessed through this class. -- cgit 1.5.1 From 40cbffb2d2ca0166f1377ac4ec5988046ea4ca10 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 18:46:19 +0100 Subject: Further registration refactoring * `RegistrationHandler.appservice_register` no longer issues an access token: instead it is left for the caller to do it. (There are two of these, one in `synapse/rest/client/v1/register.py`, which now simply calls `AuthHandler.issue_access_token`, and the other in `synapse/rest/client/v2_alpha/register.py`, which is covered below). * In `synapse/rest/client/v2_alpha/register.py`, move the generation of access_tokens into `_create_registration_details`. This means that the normal flow no longer needs to call `AuthHandler.issue_access_token`; the shared-secret flow can tell `RegistrationHandler.register` not to generate a token; and the appservice flow continues to work despite the above change. --- synapse/handlers/register.py | 13 +++++--- synapse/rest/client/v1/register.py | 4 ++- synapse/rest/client/v2_alpha/register.py | 50 +++++++++++++++++++++-------- synapse/storage/registration.py | 6 ++-- tests/rest/client/v2_alpha/test_register.py | 6 +++- 5 files changed, 57 insertions(+), 22 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6b33b27149..94b19d0cb0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -99,8 +99,13 @@ class RegistrationHandler(BaseHandler): localpart : The local part of the user ID to register. If None, one will be generated. password (str) : The password to assign to this user so they can - login again. This can be None which means they cannot login again - via a password (e.g. the user is an application service user). + login again. This can be None which means they cannot login again + via a password (e.g. the user is an application service user). + generate_token (bool): Whether a new access token should be + generated. Having this be True should be considered deprecated, + since it offers no means of associating a device_id with the + access_token. Instead you should call auth_handler.issue_access_token + after registration. Returns: A tuple of (user_id, access_token). Raises: @@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, - token=token, password_hash="", appservice_id=service_id, create_profile_with_localpart=user.localpart, ) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8e1f1b7845..28b59952c3 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -60,6 +60,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth_handler = hs.get_auth_handler() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -299,9 +300,10 @@ class RegisterRestServlet(ClientV1RestServlet): user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler - (user_id, token) = yield handler.appservice_register( + user_id = yield handler.appservice_register( user_localpart, as_token ) + token = yield self.auth_handler.issue_access_token(user_id) self._remove_session(session) defer.returnValue({ "user_id": user_id, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5db953a1e3..04004cfbbd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -226,19 +226,17 @@ class RegisterRestServlet(RestServlet): add_email = True - access_token = yield self.auth_handler.issue_access_token( + result = yield self._create_registration_details( registered_user_id ) if add_email and result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] yield self._register_email_threepid( - registered_user_id, threepid, access_token, + registered_user_id, threepid, result["access_token"], params.get("bind_email") ) - result = yield self._create_registration_details(registered_user_id, - access_token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -246,10 +244,10 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token): - (user_id, token) = yield self.registration_handler.appservice_register( + user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + defer.returnValue((yield self._create_registration_details(user_id))) @defer.inlineCallbacks def _do_shared_secret_registration(self, username, password, mac): @@ -273,10 +271,12 @@ class RegisterRestServlet(RestServlet): 403, "HMAC incorrect", ) - (user_id, token) = yield self.registration_handler.register( - localpart=username, password=password + (user_id, _) = yield self.registration_handler.register( + localpart=username, password=password, generate_token=False, ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + + result = yield self._create_registration_details(user_id) + defer.returnValue(result) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token, bind_email): @@ -349,11 +349,31 @@ class RegisterRestServlet(RestServlet): defer.returnValue() @defer.inlineCallbacks - def _create_registration_details(self, user_id, token): - refresh_token = yield self.auth_handler.issue_refresh_token(user_id) + def _create_registration_details(self, user_id): + """Complete registration of newly-registered user + + Issues access_token and refresh_token, and builds the success response + body. + + Args: + (str) user_id: full canonical @user:id + + + Returns: + defer.Deferred: (object) dictionary for response from /register + """ + + access_token = yield self.auth_handler.issue_access_token( + user_id + ) + + refresh_token = yield self.auth_handler.issue_refresh_token( + user_id + ) + defer.returnValue({ "user_id": user_id, - "access_token": token, + "access_token": access_token, "home_server": self.hs.hostname, "refresh_token": refresh_token, }) @@ -366,7 +386,11 @@ class RegisterRestServlet(RestServlet): generate_token=False, make_guest=True ) - access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + access_token = self.auth_handler.generate_access_token( + user_id, ["guest = true"] + ) + # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok + # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26ef1cfd8a..9a92b35361 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks - def register(self, user_id, token, password_hash, + def register(self, user_id, token=None, password_hash=None, was_guest=False, make_guest=False, appservice_id=None, create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. + token (str): The desired access token to use for this user. If this + is not None, the given access token is associated with the user + id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9a4215fef7..ccbb8776d3 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase): "id": "1234" } self.registration_handler.appservice_register = Mock( - return_value=(user_id, token) + return_value=user_id ) + self.auth_handler.issue_access_token = Mock(return_value=token) + (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { @@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase): } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) + self.auth_handler.issue_access_token.assert_called_once_with( + user_id) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.5.1 From 3413f1e284593aa63723cdcd52f443d63771ef62 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 10:21:42 +0100 Subject: Type annotations Add some type annotations to help PyCharm (in particular) to figure out the types of a bunch of things. --- synapse/handlers/_base.py | 4 ++++ synapse/handlers/auth.py | 4 ++++ synapse/rest/client/v1/base.py | 4 ++++ synapse/rest/client/v1/register.py | 4 ++++ synapse/rest/client/v2_alpha/register.py | 9 +++++++++ synapse/server.pyi | 21 +++++++++++++++++++++ 6 files changed, 46 insertions(+) create mode 100644 synapse/server.pyi (limited to 'synapse/handlers') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d00685c389..6264aa0d9a 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -36,6 +36,10 @@ class BaseHandler(object): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ce9bc18849..8f83923ddb 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -45,6 +45,10 @@ class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(AuthHandler, self).__init__(hs) self.checkers = { LoginType.PASSWORD: self._check_password_auth, diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 1c020b7e2c..96b49b01f2 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.hs = hs self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8e1f1b7845..efe796c65f 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__(hs) # sessions are stored as: # self.sessions = { diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5db953a1e3..2722a58e3e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -45,6 +45,10 @@ class RegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register/email/requestToken$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRequestTokenRestServlet, self).__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -77,7 +81,12 @@ class RegisterRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/server.pyi b/synapse/server.pyi new file mode 100644 index 0000000000..902f725c06 --- /dev/null +++ b/synapse/server.pyi @@ -0,0 +1,21 @@ +import synapse.handlers +import synapse.handlers.auth +import synapse.handlers.device +import synapse.storage +import synapse.state + +class HomeServer(object): + def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: + pass + + def get_datastore(self) -> synapse.storage.DataStore: + pass + + def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: + pass + + def get_handlers(self) -> synapse.handlers.Handlers: + pass + + def get_state_handler(self) -> synapse.state.StateHandler: + pass -- cgit 1.5.1 From 57dca356923f220026d31fbb58fcf37ae9b27c8e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 20 Jul 2016 13:25:06 +0100 Subject: Don't notify pusher pool for backfilled events --- synapse/handlers/federation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7622962d46..3f138daf17 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1118,11 +1118,12 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) - # this intentionally does not yield: we don't care about the result - # and don't need to wait for it. - preserve_fn(self.hs.get_pusherpool().on_new_notifications)( - event_stream_id, max_stream_id - ) + if not backfilled: + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) defer.returnValue((context, event_stream_id, max_stream_id)) -- cgit 1.5.1 From bc8f265f0a8443e918b17a94f4b2fa319e70a21f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 16:34:00 +0100 Subject: GET /devices endpoint implement a GET /devices endpoint which lists all of the user's devices. It also returns the last IP where we saw that device, so there is some dancing to fish that out of the user_ips table. --- synapse/handlers/device.py | 27 ++++++++ synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/_base.py | 13 ++-- synapse/rest/client/v2_alpha/devices.py | 51 ++++++++++++++ synapse/storage/client_ips.py | 72 ++++++++++++++++++++ synapse/storage/devices.py | 22 +++++- synapse/storage/schema/delta/33/user_ips_index.sql | 16 +++++ tests/handlers/test_device.py | 78 ++++++++++++++++++---- tests/storage/test_client_ips.py | 62 +++++++++++++++++ tests/storage/test_devices.py | 71 ++++++++++++++++++++ 10 files changed, 397 insertions(+), 17 deletions(-) create mode 100644 synapse/rest/client/v2_alpha/devices.py create mode 100644 synapse/storage/schema/delta/33/user_ips_index.sql create mode 100644 tests/storage/test_client_ips.py create mode 100644 tests/storage/test_devices.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 8d7d9874f8..6bbbf59e52 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -69,3 +69,30 @@ class DeviceHandler(BaseHandler): attempts += 1 raise StoreError(500, "Couldn't generate a device ID.") + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """ + Retrieve the given user's devices + + Args: + user_id (str): + Returns: + defer.Deferred: dict[str, dict[str, X]]: map from device_id to + info on the device + """ + + devices = yield self.store.get_devices_by_user(user_id) + + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id) for device_id in devices.keys()) + ) + + for device_id in devices.keys(): + ip = ips.get((user_id, device_id), {}) + devices[device_id].update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) + + defer.returnValue(devices) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 8b223e032b..14227f1cdb 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import ( account_data, report_event, openid, + devices, ) from synapse.http.server import JsonResource @@ -90,3 +91,4 @@ class ClientRestResource(JsonResource): account_data.register_servlets(hs, client_resource) report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) + devices.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b6faa2b0e6..20e765f48f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -25,7 +25,9 @@ import logging logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,)): +def client_v2_patterns(path_regex, releases=(0,), + v2_alpha=True, + unstable=True): """Creates a regex compiled client path with the correct client path prefix. @@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)): Returns: SRE_Pattern """ - patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] - unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") - patterns.append(re.compile("^" + unstable_prefix + path_regex)) + patterns = [] + if v2_alpha: + patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) + if unstable: + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) for release in releases: new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py new file mode 100644 index 0000000000..5cf8bd1afa --- /dev/null +++ b/synapse/rest/client/v2_alpha/devices.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet + +from ._base import client_v2_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DevicesRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + devices = yield self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + defer.returnValue((200, {"devices": devices})) + + +def register_servlets(hs, http_server): + DevicesRestServlet(hs).register(http_server) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index a90990e006..07161496ca 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from ._base import SQLBaseStore, Cache from twisted.internet import defer +logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller # times give more inserts into the database even for readonly API hits @@ -66,3 +69,72 @@ class ClientIpStore(SQLBaseStore): desc="insert_client_ip", lock=False, ) + + @defer.inlineCallbacks + def get_last_client_ip_by_device(self, devices): + """For each device_id listed, give the user_ip it was last seen on + + Args: + devices (iterable[(str, str)]): list of (user_id, device_id) pairs + + Returns: + defer.Deferred: resolves to a dict, where the keys + are (user_id, device_id) tuples. The values are also dicts, with + keys giving the column names + """ + + res = yield self.runInteraction( + "get_last_client_ip_by_device", + self._get_last_client_ip_by_device_txn, + retcols=( + "user_id", + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ), + devices=devices + ) + + ret = {(d["user_id"], d["device_id"]): d for d in res} + defer.returnValue(ret) + + @classmethod + def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): + def where_clause_for_device(d): + return + + where_clauses = [] + bindings = [] + for (user_id, device_id) in devices: + if device_id is None: + where_clauses.append("(user_id = ? AND device_id IS NULL)") + bindings.extend((user_id, )) + else: + where_clauses.append("(user_id = ? AND device_id = ?)") + bindings.extend((user_id, device_id)) + + inner_select = ( + "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " + "WHERE %(where)s " + "GROUP BY user_id, device_id" + ) % { + "where": " OR ".join(where_clauses), + } + + sql = ( + "SELECT %(retcols)s FROM user_ips " + "JOIN (%(inner_select)s) ips ON" + " user_ips.last_seen = ips.mls AND" + " user_ips.user_id = ips.user_id AND" + " (user_ips.device_id = ips.device_id OR" + " (user_ips.device_id IS NULL AND ips.device_id IS NULL)" + " )" + ) % { + "retcols": ",".join("user_ips." + c for c in retcols), + "inner_select": inner_select, + } + + txn.execute(sql, bindings) + return cls.cursor_to_dict(txn) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9065e96d28..1cc6e07f2b 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -65,7 +65,7 @@ class DeviceStore(SQLBaseStore): user_id (str): The ID of the user which owns the device device_id (str): The ID of the device to retrieve Returns: - defer.Deferred for a namedtuple containing the device information + defer.Deferred for a dict containing the device information Raises: StoreError: if the device is not found """ @@ -75,3 +75,23 @@ class DeviceStore(SQLBaseStore): retcols=("user_id", "device_id", "display_name"), desc="get_device", ) + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """Retrieve all of a user's registered devices. + + Args: + user_id (str): + Returns: + defer.Deferred: resolves to a dict from device_id to a dict + containing "device_id", "user_id" and "display_name" for each + device. + """ + devices = yield self._simple_select_list( + table="devices", + keyvalues={"user_id": user_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user" + ) + + defer.returnValue({d["device_id"]: d for d in devices}) diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql new file mode 100644 index 0000000000..8a05677d42 --- /dev/null +++ b/synapse/storage/schema/delta/33/user_ips_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX user_ips_device_id ON user_ips(user_id, device_id, last_seen); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index cc6512ccc7..c2e12135d6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,25 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from synapse import types from twisted.internet import defer -from synapse.handlers.device import DeviceHandler -from tests import unittest -from tests.utils import setup_test_homeserver - - -class DeviceHandlers(object): - def __init__(self, hs): - self.device_handler = DeviceHandler(hs) +import synapse.handlers.device +import synapse.storage +from tests import unittest, utils class DeviceTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(DeviceTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + self.handler = None # type: device.DeviceHandler + self.clock = None # type: utils.MockClock + @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) - self.hs.handlers = handlers = DeviceHandlers(self.hs) - self.handler = handlers.device_handler + hs = yield utils.setup_test_homeserver(handlers=None) + self.handler = synapse.handlers.device.DeviceHandler(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): @@ -73,3 +75,55 @@ class DeviceTestCase(unittest.TestCase): dev = yield self.handler.store.get_device("theresa", device_id) self.assertEqual(dev["display_name"], "display") + + @defer.inlineCallbacks + def test_get_devices_by_user(self): + # check this works for both devices which have a recorded client_ip, + # and those which don't. + user1 = "@boris:aaa" + user2 = "@theresa:bbb" + yield self._record_user(user1, "xyz", "display 0") + yield self._record_user(user1, "fco", "display 1", "token1", "ip1") + yield self._record_user(user1, "abc", "display 2", "token2", "ip2") + yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + + yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + + res = yield self.handler.get_devices_by_user(user1) + self.assertEqual(3, len(res.keys())) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "xyz", + "display_name": "display 0", + "last_seen_ip": None, + "last_seen_ts": None, + }, res["xyz"]) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "fco", + "display_name": "display 1", + "last_seen_ip": "ip1", + "last_seen_ts": 1000000, + }, res["fco"]) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, res["abc"]) + + @defer.inlineCallbacks + def _record_user(self, user_id, device_id, display_name, + access_token=None, ip=None): + device_id = yield self.handler.check_device_registered( + user_id=user_id, + device_id=device_id, + initial_device_display_name=display_name + ) + + if ip is not None: + yield self.store.insert_client_ip( + types.UserID.from_string(user_id), + access_token, ip, "user_agent", device_id) + self.clock.advance_time(1000) \ No newline at end of file diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py new file mode 100644 index 0000000000..1f0c0e7c37 --- /dev/null +++ b/tests/storage/test_client_ips.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.server +import synapse.storage +import synapse.types +import tests.unittest +import tests.utils + + +class ClientIpStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + self.clock = None # type: tests.utils.MockClock + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def test_insert_new_client_ip(self): + self.clock.now = 12345678 + user_id = "@user:id" + yield self.store.insert_client_ip( + synapse.types.UserID.from_string(user_id), + "access_token", "ip", "user_agent", "device_id", + ) + + # deliberately use an iterable here to make sure that the lookup + # method doesn't iterate it twice + device_list = iter(((user_id, "device_id"),)) + result = yield self.store.get_last_client_ip_by_device(device_list) + + r = result[(user_id, "device_id")] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": "device_id", + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 12345678000, + }, + r + ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py new file mode 100644 index 0000000000..d3e9d97a9a --- /dev/null +++ b/tests/storage/test_devices.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.server +import synapse.types +import tests.unittest +import tests.utils + + +class DeviceStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(DeviceStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_store_new_device(self): + yield self.store.store_device( + "user_id", "device_id", "display_name" + ) + + res = yield self.store.get_device("user_id", "device_id") + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device_id", + "display_name": "display_name", + }, res) + + @defer.inlineCallbacks + def test_get_devices_by_user(self): + yield self.store.store_device( + "user_id", "device1", "display_name 1" + ) + yield self.store.store_device( + "user_id", "device2", "display_name 2" + ) + yield self.store.store_device( + "user_id2", "device3", "display_name 3" + ) + + res = yield self.store.get_devices_by_user("user_id") + self.assertEqual(2, len(res.keys())) + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device1", + "display_name": "display_name 1", + }, res["device1"]) + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device2", + "display_name": "display_name 2", + }, res["device2"]) -- cgit 1.5.1 From 248e6770ca0faadf574cfd62f72d8e200cb5b57a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 Jul 2016 10:30:12 +0100 Subject: Cache federation state responses --- synapse/federation/federation_server.py | 66 ++++++++++++++++++++++----------- synapse/handlers/federation.py | 7 +--- synapse/handlers/room.py | 4 +- synapse/handlers/sync.py | 2 +- synapse/util/caches/response_cache.py | 13 ++++++- 5 files changed, 60 insertions(+), 32 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 85f5e752fe..d15c7e1b40 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,10 +21,11 @@ from .units import Transaction, Edu from synapse.util.async import Linearizer from synapse.util.logutils import log_function +from synapse.util.caches.response_cache import ResponseCache from synapse.events import FrozenEvent import synapse.metrics -from synapse.api.errors import FederationError, SynapseError +from synapse.api.errors import AuthError, FederationError, SynapseError from synapse.crypto.event_signing import compute_event_signature @@ -48,9 +49,15 @@ class FederationServer(FederationBase): def __init__(self, hs): super(FederationServer, self).__init__(hs) + self.auth = hs.get_auth() + self._room_pdu_linearizer = Linearizer() self._server_linearizer = Linearizer() + # We cache responses to state queries, as they take a while and often + # come in waves. + self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) + def set_handler(self, handler): """Sets the handler that the replication layer will use to communicate receipt of new PDUs from other home servers. The required methods are @@ -188,28 +195,45 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function def on_context_state_request(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): - if event_id: - pdus = yield self.handler.get_state_for_pdu( - origin, room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] + if not event_id: + raise NotImplementedError("Specify an event") + + in_room = yield self.auth.check_host_in_room(room_id, origin) + if not in_room: + raise AuthError(403, "Host not in room.") + + result = self._state_resp_cache.get((room_id, event_id)) + if not result: + with (yield self._server_linearizer.queue((origin, room_id))): + resp = yield self.response_cache.set( + (room_id, event_id), + self._on_context_state_request_compute(room_id, event_id) ) + else: + resp = yield result - for event in auth_chain: - # We sign these again because there was a bug where we - # incorrectly signed things the first time round - if self.hs.is_mine_id(event.event_id): - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] - ) - ) - else: - raise NotImplementedError("Specify an event") + defer.returnValue((200, resp)) + + @defer.inlineCallbacks + def _on_context_state_request_compute(self, room_id, event_id): + pdus = yield self.handler.get_state_for_pdu( + room_id, event_id, + ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) + + for event in auth_chain: + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) + ) defer.returnValue((200, { "pdus": [pdu.get_pdu_json() for pdu in pdus], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3f138daf17..fcad41d7b6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -991,14 +991,9 @@ class FederationHandler(BaseHandler): defer.returnValue(None) @defer.inlineCallbacks - def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): + def get_state_for_pdu(self, room_id, event_id): yield run_on_reactor() - if do_auth: - in_room = yield self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ae44c7a556..bf6b1c1535 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler): class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) - self.response_cache = ResponseCache() - self.remote_list_request_cache = ResponseCache() + self.response_cache = ResponseCache(hs) + self.remote_list_request_cache = ResponseCache(hs) self.remote_list_cache = {} self.fetch_looping_call = hs.get_clock().looping_call( self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index be26a491ff..0ee4ebe504 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -138,7 +138,7 @@ class SyncHandler(object): self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache() + self.response_cache = ResponseCache(hs) def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 36686b479e..00af539880 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -24,9 +24,12 @@ class ResponseCache(object): used rather than trying to compute a new response. """ - def __init__(self): + def __init__(self, hs, timeout_ms=0): self.pending_result_cache = {} # Requests that haven't finished yet. + self.clock = hs.get_clock() + self.timeout_sec = timeout_ms / 1000. + def get(self, key): result = self.pending_result_cache.get(key) if result is not None: @@ -39,7 +42,13 @@ class ResponseCache(object): self.pending_result_cache[key] = result def remove(r): - self.pending_result_cache.pop(key, None) + if self.timeout_sec: + self.clock.call_later( + self.timeout_sec, + self.pending_result_cache.pop, key, None, + ) + else: + self.pending_result_cache.pop(key, None) return r result.addBoth(remove) -- cgit 1.5.1 From 406f7aa0f6ca7433e52433485824e80b79930498 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 17:58:44 +0100 Subject: Implement GET /device/{deviceId} --- synapse/handlers/device.py | 46 ++++++++++++++++++++++++++------- synapse/rest/client/v2_alpha/devices.py | 25 ++++++++++++++++++ tests/handlers/test_device.py | 37 +++++++++++++++++++------- 3 files changed, 89 insertions(+), 19 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6bbbf59e52..3c88be0679 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import StoreError + +from synapse.api import errors from synapse.util import stringutils from twisted.internet import defer from ._base import BaseHandler @@ -65,10 +66,10 @@ class DeviceHandler(BaseHandler): ignore_if_known=False, ) defer.returnValue(device_id) - except StoreError: + except errors.StoreError: attempts += 1 - raise StoreError(500, "Couldn't generate a device ID.") + raise errors.StoreError(500, "Couldn't generate a device ID.") @defer.inlineCallbacks def get_devices_by_user(self, user_id): @@ -88,11 +89,38 @@ class DeviceHandler(BaseHandler): devices=((user_id, device_id) for device_id in devices.keys()) ) - for device_id in devices.keys(): - ip = ips.get((user_id, device_id), {}) - devices[device_id].update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + for device in devices.values(): + _update_device_from_client_ips(device, ips) defer.returnValue(devices) + + @defer.inlineCallbacks + def get_device(self, user_id, device_id): + """ Retrieve the given device + + Args: + user_id (str): + device_id (str) + + Returns: + defer.Deferred: dict[str, X]: info on the device + Raises: + errors.NotFoundError: if the device was not found + """ + try: + device = yield self.store.get_device(user_id, device_id) + except errors.StoreError, e: + raise errors.NotFoundError + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id),) + ) + _update_device_from_client_ips(device, ips) + defer.returnValue(device) + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 5cf8bd1afa..8b9ab4f674 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -47,5 +47,30 @@ class DevicesRestServlet(RestServlet): defer.returnValue((200, {"devices": devices})) +class DeviceRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", + releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + device = yield self.device_handler.get_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, device)) + + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b05aa9bb55..73f09874d8 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,6 +19,8 @@ import synapse.handlers.device import synapse.storage from tests import unittest, utils +user1 = "@boris:aaa" +user2 = "@theresa:bbb" class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -78,16 +80,7 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_devices_by_user(self): - # check this works for both devices which have a recorded client_ip, - # and those which don't. - user1 = "@boris:aaa" - user2 = "@theresa:bbb" - yield self._record_user(user1, "xyz", "display 0") - yield self._record_user(user1, "fco", "display 1", "token1", "ip1") - yield self._record_user(user1, "abc", "display 2", "token2", "ip2") - yield self._record_user(user1, "abc", "display 2", "token3", "ip3") - - yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + yield self._record_users() res = yield self.handler.get_devices_by_user(user1) self.assertEqual(3, len(res.keys())) @@ -113,6 +106,30 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res["abc"]) + @defer.inlineCallbacks + def test_get_device(self): + yield self._record_users() + + res = yield self.handler.get_device(user1, "abc") + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, res) + + @defer.inlineCallbacks + def _record_users(self): + # check this works for both devices which have a recorded client_ip, + # and those which don't. + yield self._record_user(user1, "xyz", "display 0") + yield self._record_user(user1, "fco", "display 1", "token1", "ip1") + yield self._record_user(user1, "abc", "display 2", "token2", "ip2") + yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + + yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + @defer.inlineCallbacks def _record_user(self, user_id, device_id, display_name, access_token=None, ip=None): -- cgit 1.5.1 From 1c3c202b969d6a7e5e4af2b2dca370f053b92c9f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 21 Jul 2016 13:15:15 +0100 Subject: Fix PEP8 errors --- synapse/handlers/device.py | 2 +- tests/handlers/test_device.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 3c88be0679..110f5fbb5c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -109,7 +109,7 @@ class DeviceHandler(BaseHandler): """ try: device = yield self.store.get_device(user_id, device_id) - except errors.StoreError, e: + except errors.StoreError: raise errors.NotFoundError ips = yield self.store.get_last_client_ip_by_device( devices=((user_id, device_id),) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 73f09874d8..87c3c75aea 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -22,6 +22,7 @@ from tests import unittest, utils user1 = "@boris:aaa" user2 = "@theresa:bbb" + class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) -- cgit 1.5.1 From 55abbe1850efff95efe9935873b666e5fc4bf0e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 21 Jul 2016 15:55:13 +0100 Subject: make /devices return a list Turns out I specced this to return a list of devices rather than a dict of them --- synapse/handlers/device.py | 10 +++++----- tests/handlers/test_device.py | 11 +++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 110f5fbb5c..1f9e15c33c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -79,17 +79,17 @@ class DeviceHandler(BaseHandler): Args: user_id (str): Returns: - defer.Deferred: dict[str, dict[str, X]]: map from device_id to - info on the device + defer.Deferred: list[dict[str, X]]: info on each device """ - devices = yield self.store.get_devices_by_user(user_id) + device_map = yield self.store.get_devices_by_user(user_id) ips = yield self.store.get_last_client_ip_by_device( - devices=((user_id, device_id) for device_id in devices.keys()) + devices=((user_id, device_id) for device_id in device_map.keys()) ) - for device in devices.values(): + devices = device_map.values() + for device in devices: _update_device_from_client_ips(device, ips) defer.returnValue(devices) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 87c3c75aea..331aa13fed 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -84,28 +84,31 @@ class DeviceTestCase(unittest.TestCase): yield self._record_users() res = yield self.handler.get_devices_by_user(user1) - self.assertEqual(3, len(res.keys())) + self.assertEqual(3, len(res)) + device_map = { + d["device_id"]: d for d in res + } self.assertDictContainsSubset({ "user_id": user1, "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, "last_seen_ts": None, - }, res["xyz"]) + }, device_map["xyz"]) self.assertDictContainsSubset({ "user_id": user1, "device_id": "fco", "display_name": "display 1", "last_seen_ip": "ip1", "last_seen_ts": 1000000, - }, res["fco"]) + }, device_map["fco"]) self.assertDictContainsSubset({ "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }, res["abc"]) + }, device_map["abc"]) @defer.inlineCallbacks def test_get_device(self): -- cgit 1.5.1 From dad2da7e54a4f0e92185e4f8553fb51b037c0bd3 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jul 2016 17:00:56 +0100 Subject: Log the hostname the reCAPTCHA was completed on This could be useful information to have in the logs. Also comment about how & why we don't verify the hostname. --- synapse/handlers/auth.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8f83923ddb..6fff7e7d03 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -279,8 +279,17 @@ class AuthHandler(BaseHandler): data = pde.response resp_body = simplejson.loads(data) - if 'success' in resp_body and resp_body['success']: - defer.returnValue(True) + if 'success' in resp_body: + # Note that we do NOT check the hostname here: we explicitly + # intend the CAPTCHA to be presented by whatever client the + # user is using, we just care that they have completed a CAPTCHA. + logger.info( + "%s reCAPTCHA from hostname %s", + "Successful" if resp_body['success'] else "Failed", + resp_body['hostname'] + ) + if resp_body['success']: + defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) @defer.inlineCallbacks -- cgit 1.5.1 From 7ed58bb3476c4a18a9af97b8ee3358dac00098eb Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jul 2016 17:18:50 +0100 Subject: Use get to avoid KeyErrors --- synapse/handlers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6fff7e7d03..d5d2072436 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -286,7 +286,7 @@ class AuthHandler(BaseHandler): logger.info( "%s reCAPTCHA from hostname %s", "Successful" if resp_body['success'] else "Failed", - resp_body['hostname'] + resp_body.get('hostname') ) if resp_body['success']: defer.returnValue(True) -- cgit 1.5.1 From 436bffd15fb8382a0d2dddd3c6f7a077ba751da2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 22 Jul 2016 14:52:53 +0100 Subject: Implement deleting devices --- synapse/handlers/auth.py | 22 ++++++++++++++++-- synapse/handlers/device.py | 27 +++++++++++++++++++++- synapse/rest/client/v1/login.py | 13 ++++++++--- synapse/rest/client/v2_alpha/devices.py | 14 +++++++++++ synapse/rest/client/v2_alpha/register.py | 10 ++++---- synapse/storage/devices.py | 15 ++++++++++++ synapse/storage/registration.py | 26 +++++++++++++++++---- .../schema/delta/33/access_tokens_device_index.sql | 17 ++++++++++++++ .../schema/delta/33/refreshtoken_device_index.sql | 17 ++++++++++++++ tests/handlers/test_device.py | 22 ++++++++++++++++-- tests/rest/client/v2_alpha/test_register.py | 14 +++++++---- 11 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 synapse/storage/schema/delta/33/access_tokens_device_index.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device_index.sql (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d2072436..2e138f328f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -77,6 +77,7 @@ class AuthHandler(BaseHandler): self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -374,7 +375,8 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None): + def get_login_tuple_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ Gets login tuple for the user with the given user ID. @@ -383,9 +385,15 @@ class AuthHandler(BaseHandler): The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. + The device will be recorded in the table if it is not there already. + Args: user_id (str): canonical User ID - device_id (str): the device ID to associate with the access token + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: A tuple of: The access token for the user's session. @@ -397,6 +405,16 @@ class AuthHandler(BaseHandler): logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1f9e15c33c..a7a192e1c9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler): Args: user_id (str): - device_id (str) + device_id (str): Returns: defer.Deferred: dict[str, X]: info on the device @@ -117,6 +117,31 @@ class DeviceHandler(BaseHandler): _update_device_from_client_ips(device, ips) defer.returnValue(device) + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens(user_id, + device_id=device_id) + + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e8b791519c..92fcae674a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id + registered_user_id, device_id, + login_submission.get("initial_device_display_name") ) ) result = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 8b9ab4f674..30ef8b3da9 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -70,6 +70,20 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, device)) + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c8c9395fc6..9f599ea8bb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token = yield self.auth_handler.issue_access_token( - user_id, device_id=device_id + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - refresh_token = yield self.auth_handler.issue_refresh_token( - user_id, device_id=device_id - ) defer.returnValue({ "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1cc6e07f2b..4689980f80 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -76,6 +76,21 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred + """ + return self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9a92b35361..935e82bf7a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,18 +18,31 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes - -from ._base import SQLBaseStore +from synapse.storage import background_updates from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def user_delete_access_tokens(self, user_id, except_token_ids=[], + device_id=None): def f(txn): sql = "SELECT token FROM access_tokens WHERE user_id = ?" clauses = [user_id] + if device_id is not None: + sql += " AND device_id = ?" + clauses.append(device_id) + if except_token_ids: sql += " AND id NOT IN (%s)" % ( ",".join(["?" for _ in except_token_ids]), diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql new file mode 100644 index 0000000000..bb225dafbf --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('refresh_tokens_device_index', '{}'); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 331aa13fed..214e722eb3 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse import types + from twisted.internet import defer +import synapse.api.errors import synapse.handlers.device + import synapse.storage +from synapse import types from tests import unittest, utils user1 = "@boris:aaa" @@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore - self.handler = None # type: device.DeviceHandler + self.handler = None # type: synapse.handlers.device.DeviceHandler self.clock = None # type: utils.MockClock @defer.inlineCallbacks @@ -123,6 +126,21 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res) + @defer.inlineCallbacks + def test_delete_device(self): + yield self._record_users() + + # delete the device + yield self.handler.delete_device(user1, "abc") + + # check the device was deleted + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.get_device(user1, "abc") + + # we'd like to check the access token was invalidated, but that's a + # bit of a PITA. + + @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3bd7065e32..8ac56a1fb2 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=user_id ) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname } self.assertDictContainsSubset(det_data, result) @@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) self.device_handler.check_device_registered = \ Mock(return_value=device_id) @@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname, "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) - self.auth_handler.issue_access_token.assert_called_once_with( - user_id, device_id=device_id) + self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, initial_device_display_name=None) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.5.1 From 012b4c19132d57fdbc1b6b0e304eb60eaf19200f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 17:51:24 +0100 Subject: Implement updating devices You can update the displayname of devices now. --- synapse/handlers/device.py | 24 ++++++++++++++++++++++ synapse/rest/client/v2_alpha/devices.py | 24 +++++++++++++++------- synapse/storage/devices.py | 27 ++++++++++++++++++++++++- tests/handlers/test_device.py | 16 +++++++++++++++ tests/storage/test_devices.py | 36 +++++++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 8 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a7a192e1c9..9e65d85e6d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler): yield self.store.user_delete_access_tokens(user_id, device_id=device_id) + @defer.inlineCallbacks + def update_device(self, user_id, device_id, content): + """ Update the given device + + Args: + user_id (str): + device_id (str): + content (dict): body of update request + + Returns: + defer.Deferred: + """ + + try: + yield self.store.update_device( + user_id, + device_id, + new_display_name=content.get("display_name") + ) + except errors.StoreError, e: + if e.code == 404: + raise errors.NotFoundError() + else: + raise def _update_device_from_client_ips(device, client_ips): diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 30ef8b3da9..8fbd3d3dfc 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -13,19 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging -from synapse.http.servlet import RestServlet +from twisted.internet import defer +from synapse.http import servlet from ._base import client_v2_patterns -import logging - - logger = logging.getLogger(__name__) -class DevicesRestServlet(RestServlet): +class DevicesRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) def __init__(self, hs): @@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet): defer.returnValue((200, {"devices": devices})) -class DeviceRestServlet(RestServlet): +class DeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", releases=[], v2_alpha=False) @@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, {})) + @defer.inlineCallbacks + def on_PUT(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + + body = servlet.parse_json_object_from_request(request) + yield self.device_handler.update_device( + requester.user.to_string(), + device_id, + body + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 4689980f80..afd6530cab 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore): Args: user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + device_id (str): The ID of the device to delete Returns: defer.Deferred """ @@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore): desc="delete_device", ) + def update_device(self, user_id, device_id, new_display_name=None): + """Update a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to update + new_display_name (str|None): new displayname for device; None + to leave unchanged + Raises: + StoreError: if the device is not found + Returns: + defer.Deferred + """ + updates = {} + if new_display_name is not None: + updates["display_name"] = new_display_name + if not updates: + return defer.succeed(None) + return self._simple_update_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues=updates, + desc="update_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 214e722eb3..85a970a6c9 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase): # we'd like to check the access token was invalidated, but that's a # bit of a PITA. + @defer.inlineCallbacks + def test_update_device(self): + yield self._record_users() + + update = {"display_name": "new display"} + yield self.handler.update_device(user1, "abc", update) + + res = yield self.handler.get_device(user1, "abc") + self.assertEqual(res["display_name"], "new display") + + @defer.inlineCallbacks + def test_update_unknown_device(self): + update = {"display_name": "new_display"} + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.update_device("user_id", "unknown_device_id", + update) @defer.inlineCallbacks def _record_users(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index a6ce993375..f8725acea0 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -15,6 +15,7 @@ from twisted.internet import defer +import synapse.api.errors import tests.unittest import tests.utils @@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase): "device_id": "device2", "display_name": "display_name 2", }, res["device2"]) + + @defer.inlineCallbacks + def test_update_device(self): + yield self.store.store_device( + "user_id", "device_id", "display_name 1" + ) + + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 1", res["display_name"]) + + # do a no-op first + yield self.store.update_device( + "user_id", "device_id", + ) + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 1", res["display_name"]) + + # do the update + yield self.store.update_device( + "user_id", "device_id", + new_display_name="display_name 2", + ) + + # check it worked + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 2", res["display_name"]) + + @defer.inlineCallbacks + def test_update_unknown_device(self): + with self.assertRaises(synapse.api.errors.StoreError) as cm: + yield self.store.update_device( + "user_id", "unknown_device_id", + new_display_name="display_name 2", + ) + self.assertEqual(404, cm.exception.code) -- cgit 1.5.1 From 8e0249416643f20f0c4cd8f2e19cf45ea63289d3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 26 Jul 2016 11:09:47 +0100 Subject: Delete refresh tokens when deleting devices --- synapse/handlers/device.py | 6 ++-- synapse/storage/registration.py | 58 +++++++++++++++++++++++++++++--------- tests/storage/test_registration.py | 34 ++++++++++++++++++++++ 3 files changed, 83 insertions(+), 15 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9e65d85e6d..eaead50800 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -138,8 +138,10 @@ class DeviceHandler(BaseHandler): else: raise - yield self.store.user_delete_access_tokens(user_id, - device_id=device_id) + yield self.store.user_delete_access_tokens( + user_id, device_id=device_id, + delete_refresh_tokens=True, + ) @defer.inlineCallbacks def update_device(self, user_id, device_id, content): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 935e82bf7a..d9555e073a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -252,20 +252,36 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_ids=[], - device_id=None): - def f(txn): - sql = "SELECT token FROM access_tokens WHERE user_id = ?" + device_id=None, + delete_refresh_tokens=False): + """ + Invalidate access/refresh tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_ids (list[str]): list of access_tokens which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + delete_refresh_tokens (bool): True to delete refresh tokens as + well as access tokens. + Returns: + defer.Deferred: + """ + def f(txn, table, except_tokens, call_after_delete): + sql = "SELECT token FROM %s WHERE user_id = ?" % table clauses = [user_id] if device_id is not None: sql += " AND device_id = ?" clauses.append(device_id) - if except_token_ids: + if except_tokens: sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_token_ids]), + ",".join(["?" for _ in except_tokens]), ) - clauses += except_token_ids + clauses += except_tokens txn.execute(sql, clauses) @@ -274,16 +290,33 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): n = 100 chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] for chunk in chunks: - for row in chunk: - txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + if call_after_delete: + for row in chunk: + txn.call_after(call_after_delete, (row[0],)) txn.execute( - "DELETE FROM access_tokens WHERE token in (%s)" % ( + "DELETE FROM %s WHERE token in (%s)" % ( + table, ",".join(["?" for _ in chunk]), ), [r[0] for r in chunk] ) - yield self.runInteraction("user_delete_access_tokens", f) + # delete refresh tokens first, to stop new access tokens being + # allocated while our backs are turned + if delete_refresh_tokens: + yield self.runInteraction( + "user_delete_access_tokens", f, + table="refresh_tokens", + except_tokens=[], + call_after_delete=None, + ) + + yield self.runInteraction( + "user_delete_access_tokens", f, + table="access_tokens", + except_tokens=except_token_ids, + call_after_delete=self.get_user_by_access_token.invalidate, + ) def delete_access_token(self, access_token): def f(txn): @@ -306,9 +339,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and the ID of their access token. - Raises: - StoreError if no user was found. + defer.Deferred: None, if the token did not match, ootherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( "get_user_by_access_token", diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b03ca303a2..f7d74dea8e 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase): with self.assertRaises(StoreError): yield self.store.exchange_refresh_token(last_token, generator.generate) + @defer.inlineCallbacks + def test_user_delete_access_tokens(self): + # add some tokens + generator = TokenGenerator() + refresh_token = generator.generate(self.user_id) + yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) + yield self.store.add_refresh_token_to_user(self.user_id, refresh_token, + self.device_id) + + # now delete some + yield self.store.user_delete_access_tokens( + self.user_id, device_id=self.device_id, delete_refresh_tokens=True) + + # check they were deleted + user = yield self.store.get_user_by_access_token(self.tokens[1]) + self.assertIsNone(user, "access token was not deleted by device_id") + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(refresh_token, + generator.generate) + + # check the one not associated with the device was not deleted + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertEqual(self.user_id, user["name"]) + + # now delete the rest + yield self.store.user_delete_access_tokens( + self.user_id, delete_refresh_tokens=True) + + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertIsNone(user, + "access token was not deleted without device_id") + class TokenGenerator: def __init__(self): -- cgit 1.5.1 From eb359eced44407b1ee9648f10fdf3df63c8d40ad Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 26 Jul 2016 16:46:53 +0100 Subject: Add `create_requester` function Wrap the `Requester` constructor with a function which provides sensible defaults, and use it throughout --- synapse/api/auth.py | 24 +++++++++++------------- synapse/handlers/_base.py | 13 +++++++------ synapse/handlers/profile.py | 12 +++++++----- synapse/handlers/register.py | 16 +++++++++------- synapse/handlers/room_member.py | 20 +++++++++----------- synapse/rest/client/v2_alpha/keys.py | 10 ++++------ synapse/types.py | 33 ++++++++++++++++++++++++++++++++- tests/handlers/test_profile.py | 10 ++++++---- tests/replication/test_resource.py | 20 +++++++++++--------- tests/rest/client/v1/test_profile.py | 13 +++++-------- tests/utils.py | 5 ----- 11 files changed, 101 insertions(+), 75 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index eca8513905..eecf3b0b2a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import pymacaroons from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json, SignatureVerifyException - from twisted.internet import defer +from unpaddedbase64 import decode_base64 +import synapse.types from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import Requester, UserID, get_domain_from_id -from synapse.util.logutils import log_function +from synapse.types import UserID, get_domain_from_id from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logutils import log_function from synapse.util.metrics import Measure -from unpaddedbase64 import decode_base64 - -import logging -import pymacaroons logger = logging.getLogger(__name__) @@ -566,8 +566,7 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - defer.Deferred: resolves to a namedtuple including "user" (UserID) - "access_token_id" (int), "is_guest" (bool) + defer.Deferred: resolves to a ``synapse.types.Requester`` object Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -576,9 +575,7 @@ class Auth(object): user_id = yield self._get_appservice_user_id(request.args) if user_id: request.authenticated_entity = user_id - defer.returnValue( - Requester(UserID.from_string(user_id), "", False) - ) + defer.returnValue(synapse.types.create_requester(user_id)) access_token = request.args["access_token"][0] user_info = yield self.get_user_by_access_token(access_token, rights) @@ -612,7 +609,8 @@ class Auth(object): request.authenticated_entity = user.to_string() - defer.returnValue(Requester(user, token_id, is_guest)) + defer.returnValue(synapse.types.create_requester( + user, token_id, is_guest, device_id)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6264aa0d9a..11081a0cd5 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import LimitExceededError +import synapse.types from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, Requester - - -import logging +from synapse.api.errors import LimitExceededError +from synapse.types import UserID logger = logging.getLogger(__name__) @@ -124,7 +124,8 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = Requester(target_user, "", True) + requester = synapse.types.create_requester( + target_user, is_guest=True) handler = self.hs.get_handlers().room_member_handler yield handler.update_membership( requester, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 711a6a567f..d9ac09078d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer +import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID, Requester - +from synapse.types import UserID from ._base import BaseHandler -import logging - logger = logging.getLogger(__name__) @@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler): try: # Assume the user isn't a guest because we don't let guests set # profile or avatar data. - requester = Requester(user, "", False) + # XXX why are we recreating `requester` here for each room? + # what was wrong with the `requester` we were passed? + requester = synapse.types.create_requester(user) yield handler.update_membership( requester, user, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 94b19d0cb0..b9b5880d64 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,18 +14,19 @@ # limitations under the License. """Contains functions for registering clients.""" +import logging +import urllib + from twisted.internet import defer -from synapse.types import UserID, Requester +import synapse.types from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) -from ._base import BaseHandler -from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient - -import logging -import urllib +from synapse.types import UserID +from synapse.util.async import run_on_reactor +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler + requester = synapse.types.create_requester(user) yield profile_handler.set_displayname( - user, Requester(user, token, False), displayname + user, requester, displayname ) defer.returnValue((user_id, token)) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7e616f44fd..8cec8fc4ed 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -14,24 +14,22 @@ # limitations under the License. -from twisted.internet import defer +import logging -from ._base import BaseHandler +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from twisted.internet import defer +from unpaddedbase64 import decode_base64 -from synapse.types import UserID, RoomID, Requester +import synapse.types from synapse.api.constants import ( EventTypes, Membership, ) from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.types import UserID, RoomID from synapse.util.async import Linearizer from synapse.util.distributor import user_left_room, user_joined_room - -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes - -from unpaddedbase64 import decode_base64 - -import logging +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler): ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: - requester = Requester(target_user, None, False) + requester = synapse.types.create_requester(target_user) message_handler = self.hs.get_handlers().message_handler prev_event = message_handler.deduplicate_state_event(event, context) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 89ab39491c..56364af337 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -13,18 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import simplejson as json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID - -from canonicaljson import encode_canonical_json - from ._base import client_v2_patterns -import logging -import simplejson as json - logger = logging.getLogger(__name__) diff --git a/synapse/types.py b/synapse/types.py index f639651a73..5349b0c450 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError from collections import namedtuple -Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) +Requester = namedtuple("Requester", + ["user", "access_token_id", "is_guest", "device_id"]) +""" +Represents the user making a request + +Attributes: + user (UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time +""" + + +def create_requester(user_id, access_token_id=None, is_guest=False, + device_id=None): + """ + Create a new ``Requester`` object + + Args: + user_id (str|UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + return Requester(user_id, access_token_id, is_guest, device_id) def get_domain_from_id(string): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 4f2c14e4ff..f1f664275f 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,11 +19,12 @@ from twisted.internet import defer from mock import Mock, NonCallableMock +import synapse.types from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID -from tests.utils import setup_test_homeserver, requester_for_user +from tests.utils import setup_test_homeserver class ProfileHandlers(object): @@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase): def test_set_my_name(self): yield self.handler.set_displayname( self.frank, - requester_for_user(self.frank), + synapse.types.create_requester(self.frank), "Frank Jr." ) @@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase): def test_set_my_name_noauth(self): d = self.handler.set_displayname( self.frank, - requester_for_user(self.bob), + synapse.types.create_requester(self.bob), "Frank Jr." ) @@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): yield self.handler.set_avatar_url( - self.frank, requester_for_user(self.frank), "http://my.server/pic.gif" + self.frank, synapse.types.create_requester(self.frank), + "http://my.server/pic.gif" ) self.assertEquals( diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 842e3d29d7..e70ac6f14d 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.resource import ReplicationResource -from synapse.types import Requester, UserID +import contextlib +import json +from mock import Mock, NonCallableMock from twisted.internet import defer + +import synapse.types +from synapse.replication.resource import ReplicationResource +from synapse.types import UserID from tests import unittest -from tests.utils import setup_test_homeserver, requester_for_user -from mock import Mock, NonCallableMock -import json -import contextlib +from tests.utils import setup_test_homeserver class ReplicationResourceCase(unittest.TestCase): @@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase): def test_events_and_state(self): get = self.get(events="-1", state="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( - Requester(self.user, "", False), {} + synapse.types.create_requester(self.user), {} ) code, body = yield get self.assertEquals(code, 200) @@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase): def send_text_message(self, room_id, message): handler = self.hs.get_handlers().message_handler event = yield handler.create_and_send_nonmember_event( - requester_for_user(self.user), + synapse.types.create_requester(self.user), { "type": "m.room.message", "content": {"body": "message", "msgtype": "m.text"}, @@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase): @defer.inlineCallbacks def create_room(self): result = yield self.hs.get_handlers().room_creation_handler.create_room( - Requester(self.user, "", False), {} + synapse.types.create_requester(self.user), {} ) defer.returnValue(result["room_id"]) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index af02fce8fb..1e95e97538 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,17 +14,14 @@ # limitations under the License. """Tests REST events for /profile paths.""" -from tests import unittest -from twisted.internet import defer - from mock import Mock +from twisted.internet import defer -from ....utils import MockHttpResource, setup_test_homeserver - +import synapse.types from synapse.api.errors import SynapseError, AuthError -from synapse.types import Requester, UserID - from synapse.rest.client.v1 import profile +from tests import unittest +from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" PATH_PREFIX = "/_matrix/client/api/v1" @@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return Requester(UserID.from_string(myid), "", False) + return synapse.types.create_requester(myid) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/utils.py b/tests/utils.py index ed547bc39b..915b934e94 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer from synapse.federation.transport import server -from synapse.types import Requester from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.logcontext import LoggingContext @@ -512,7 +511,3 @@ class DeferredMockCallable(object): "call(%s)" % _format_call(c[0], c[1]) for c in calls ]) ) - - -def requester_for_user(user): - return Requester(user, None, False) -- cgit 1.5.1 From d47115ff8bf3ab5952f053db578a519e8e3f930c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 27 Jul 2016 12:18:03 +0100 Subject: Delete e2e keys on device delete --- synapse/handlers/device.py | 4 ++++ synapse/rest/client/v2_alpha/keys.py | 13 +++++++++---- synapse/storage/end_to_end_keys.py | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index eaead50800..f4bf159bb5 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -143,6 +143,10 @@ class DeviceHandler(BaseHandler): delete_refresh_tokens=True, ) + yield self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=device_id + ) + @defer.inlineCallbacks def update_device(self, user_id, device_id, content): """ Update the given device diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 0bf32a089b..4629f4bfde 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -86,10 +86,6 @@ class KeyUploadServlet(RestServlet): raise synapse.api.errors.SynapseError( 400, "Can only upload keys for current device" ) - - self.device_handler.check_device_registered( - user_id, device_id, "unknown device" - ) else: device_id = requester.device_id @@ -131,6 +127,15 @@ class KeyUploadServlet(RestServlet): user_id, device_id, time_now, key_list ) + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered( + user_id, device_id, "unknown device" + ) + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue((200, {"one_time_key_counts": result})) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2e89066515..62b7790e91 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import twisted.internet.defer + from ._base import SQLBaseStore @@ -123,3 +125,16 @@ class EndToEndKeyStore(SQLBaseStore): return self.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) + + @twisted.internet.defer.inlineCallbacks + def delete_e2e_keys_by_device(self, user_id, device_id): + yield self._simple_delete( + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_e2e_device_keys_by_device" + ) + yield self._simple_delete( + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_e2e_one_time_keys_by_device" + ) -- cgit 1.5.1 From 05f6447301ddc72cec7564f9d39f3e16aaa728c6 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Wed, 27 Jul 2016 17:54:26 +0100 Subject: Forbid non-ASes from registering users whose names begin with '_' (SYN-738) --- synapse/handlers/register.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b9b5880d64..dd75c4fecf 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -53,6 +53,13 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME ) + if localpart[0] == '_': + raise SynapseError( + 400, + "User ID may not begin with _", + Codes.INVALID_USERNAME + ) + user = UserID(localpart, self.hs.hostname) user_id = user.to_string() -- cgit 1.5.1 From 1e2740caabe348e4131fe6bd2d777fc7483909a4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jul 2016 16:08:33 +0100 Subject: Handle the case of missing auth events when joining a room --- synapse/handlers/federation.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3f138daf17..cab7efb5db 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -124,7 +124,7 @@ class FederationHandler(BaseHandler): try: event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) except AuthError as e: raise FederationError( @@ -637,7 +637,7 @@ class FederationHandler(BaseHandler): pass event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) with PreserveLoggingContext(): @@ -1155,7 +1155,7 @@ class FederationHandler(BaseHandler): ) @defer.inlineCallbacks - def _persist_auth_tree(self, auth_events, state, event): + def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event seperately. @@ -1172,7 +1172,7 @@ class FederationHandler(BaseHandler): event_map = { e.event_id: e - for e in auth_events + for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1181,10 +1181,29 @@ class FederationHandler(BaseHandler): create_event = e break + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id, _ in e.auth_events: + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = yield self.replication_layer.get_pdu( + [origin], + e_id, + outlier=True, + timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] for e_id, _ in e.auth_events + if e_id in event_map } if create_event: auth_for_e[(EventTypes.Create, "")] = create_event -- cgit 1.5.1 From 3d13c3a2952263c38111fcf95d625e316416b52b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jul 2016 10:45:05 +0100 Subject: Update docstring --- synapse/handlers/federation.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index cab7efb5db..9583629388 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1160,6 +1160,12 @@ class FederationHandler(BaseHandler): state and event. Then persists the auth chain and state atomically. Persists the event seperately. + Args: + origin (str): Where the events came from + auth_events (list) + state (list) + event (Event) + Returns: 2-tuple of (event_stream_id, max_stream_id) from the persist_event call for `event` -- cgit 1.5.1 From c51a52f3002abf4597952e07759c6ab3016e3497 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jul 2016 11:17:04 +0100 Subject: Mention that func will fetch auth events --- synapse/handlers/federation.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 9583629388..1323235b62 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1160,6 +1160,8 @@ class FederationHandler(BaseHandler): state and event. Then persists the auth chain and state atomically. Persists the event seperately. + Will attempt to fetch missing auth events. + Args: origin (str): Where the events came from auth_events (list) -- cgit 1.5.1 From 986615b0b21271959adb9d64291761244e4175bd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 1 Aug 2016 18:02:07 +0100 Subject: Move e2e query logic into a handler --- synapse/handlers/e2e_keys.py | 67 ++++++++++++++++++++++++++++++++++++ synapse/rest/client/v2_alpha/keys.py | 46 ++++--------------------- synapse/server.py | 45 +++++++++++++----------- synapse/server.pyi | 4 +++ 4 files changed, 102 insertions(+), 60 deletions(-) create mode 100644 synapse/handlers/e2e_keys.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py new file mode 100644 index 0000000000..73a14cf952 --- /dev/null +++ b/synapse/handlers/e2e_keys.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging + +from twisted.internet import defer + +import synapse.types +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + + +class E2eKeysHandler(BaseHandler): + def __init__(self, hs): + super(E2eKeysHandler, self).__init__(hs) + self.store = hs.get_datastore() + self.federation = hs.get_replication_layer() + self.is_mine = hs.is_mine + + @defer.inlineCallbacks + def query_devices(self, query_body): + local_query = [] + remote_queries = {} + for user_id, device_ids in query_body.get("device_keys", {}).items(): + user = synapse.types.UserID.from_string(user_id) + if self.is_mine(user): + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + else: + remote_queries.setdefault(user.domain, {})[user_id] = list( + device_ids + ) + results = yield self.store.get_e2e_device_keys(local_query) + + json_result = {} + for user_id, device_keys in results.items(): + for device_id, json_bytes in device_keys.items(): + json_result.setdefault(user_id, {})[ + device_id] = json.loads( + json_bytes + ) + + for destination, device_keys in remote_queries.items(): + remote_result = yield self.federation.query_client_keys( + destination, {"device_keys": device_keys} + ) + for user_id, keys in remote_result["device_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + defer.returnValue((200, {"device_keys": json_result})) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index dc1d4d8fc6..705a0b6c17 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -186,17 +186,19 @@ class KeyQueryServlet(RestServlet): ) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(KeyQueryServlet, self).__init__() - self.store = hs.get_datastore() self.auth = hs.get_auth() - self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) + result = yield self.e2e_keys_handler.query_devices(body) defer.returnValue(result) @defer.inlineCallbacks @@ -205,45 +207,11 @@ class KeyQueryServlet(RestServlet): auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] - result = yield self.handle_request( + result = yield self.e2e_keys_handler.query_devices( {"device_keys": {user_id: device_ids}} ) defer.returnValue(result) - @defer.inlineCallbacks - def handle_request(self, body): - local_query = [] - remote_queries = {} - for user_id, device_ids in body.get("device_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): - if not device_ids: - local_query.append((user_id, None)) - else: - for device_id in device_ids: - local_query.append((user_id, device_id)) - else: - remote_queries.setdefault(user.domain, {})[user_id] = list( - device_ids - ) - results = yield self.store.get_e2e_device_keys(local_query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[device_id] = json.loads( - json_bytes - ) - - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( - destination, {"device_keys": device_keys} - ) - for user_id, keys in remote_result["device_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - defer.returnValue((200, {"device_keys": json_result})) - class OneTimeKeyServlet(RestServlet): """ diff --git a/synapse/server.py b/synapse/server.py index e8b166990d..6bb4988309 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -19,39 +19,38 @@ # partial one for unit test mocking. # Imports required for the default HomeServer() implementation -from twisted.web.client import BrowserLikePolicyForHTTPS +import logging + from twisted.enterprise import adbapi +from twisted.web.client import BrowserLikePolicyForHTTPS -from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.api.auth import Auth +from synapse.api.filtering import Filtering +from synapse.api.ratelimiting import Ratelimiter from synapse.appservice.api import ApplicationServiceApi +from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.crypto.keyring import Keyring +from synapse.events.builder import EventBuilderFactory from synapse.federation import initialize_http_replication -from synapse.handlers.device import DeviceHandler -from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory -from synapse.notifier import Notifier -from synapse.api.auth import Auth from synapse.handlers import Handlers +from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.handlers.auth import AuthHandler +from synapse.handlers.device import DeviceHandler +from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.presence import PresenceHandler +from synapse.handlers.room import RoomListHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler -from synapse.handlers.room import RoomListHandler -from synapse.handlers.auth import AuthHandler -from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory +from synapse.http.matrixfederationclient import MatrixFederationHttpClient +from synapse.notifier import Notifier +from synapse.push.pusherpool import PusherPool +from synapse.rest.media.v1.media_repository import MediaRepository from synapse.state import StateHandler from synapse.storage import DataStore +from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor -from synapse.streams.events import EventSources -from synapse.api.ratelimiting import Ratelimiter -from synapse.crypto.keyring import Keyring -from synapse.push.pusherpool import PusherPool -from synapse.events.builder import EventBuilderFactory -from synapse.api.filtering import Filtering -from synapse.rest.media.v1.media_repository import MediaRepository - -from synapse.http.matrixfederationclient import MatrixFederationHttpClient - -import logging - logger = logging.getLogger(__name__) @@ -94,6 +93,7 @@ class HomeServer(object): 'room_list_handler', 'auth_handler', 'device_handler', + 'e2e_keys_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -202,6 +202,9 @@ class HomeServer(object): def build_device_handler(self): return DeviceHandler(self) + def build_e2e_keys_handler(self): + return E2eKeysHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/server.pyi b/synapse/server.pyi index 902f725c06..c0aa868c4f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,6 +1,7 @@ import synapse.handlers import synapse.handlers.auth import synapse.handlers.device +import synapse.handlers.e2e_keys import synapse.storage import synapse.state @@ -14,6 +15,9 @@ class HomeServer(object): def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: pass + def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: + pass + def get_handlers(self) -> synapse.handlers.Handlers: pass -- cgit 1.5.1 From 1efee2f52b931ddcd90e87d06c7ea614da2c9cd0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 2 Aug 2016 18:06:31 +0100 Subject: E2E keys: Make federation query share code with client query Refactor the e2e query handler to separate out the local query, and then make the federation handler use it. --- synapse/federation/federation_server.py | 20 +----- synapse/federation/transport/server.py | 4 +- synapse/handlers/e2e_keys.py | 115 +++++++++++++++++++++++++------- 3 files changed, 92 insertions(+), 47 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 85f5e752fe..e637f2a8bd 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -348,27 +348,9 @@ class FederationServer(FederationBase): (200, send_content) ) - @defer.inlineCallbacks @log_function def on_query_client_keys(self, origin, content): - query = [] - for user_id, device_ids in content.get("device_keys", {}).items(): - if not device_ids: - query.append((user_id, None)) - else: - for device_id in device_ids: - query.append((user_id, device_id)) - - results = yield self.store.get_e2e_device_keys(query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[device_id] = json.loads( - json_bytes - ) - - defer.returnValue({"device_keys": json_result}) + return self.on_query_request("client_keys", content) @defer.inlineCallbacks @log_function diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 26fa88ae84..1a88413d18 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -367,10 +367,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet): PATH = "/user/keys/query" - @defer.inlineCallbacks def on_POST(self, origin, content, query): - response = yield self.handler.on_query_client_keys(origin, content) - defer.returnValue((200, response)) + return self.handler.on_query_client_keys(origin, content) class FederationClientKeysClaimServlet(BaseFederationServlet): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 73a14cf952..9c7e9494d6 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import json import logging from twisted.internet import defer +from synapse.api import errors import synapse.types + from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -29,39 +32,101 @@ class E2eKeysHandler(BaseHandler): super(E2eKeysHandler, self).__init__(hs) self.store = hs.get_datastore() self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id + + # doesn't really work as part of the generic query API, because the + # query request requires an object POST, but we abuse the + # "query handler" interface. + self.federation.register_query_handler( + "client_keys", self.on_federation_query_client_keys + ) @defer.inlineCallbacks def query_devices(self, query_body): - local_query = [] - remote_queries = {} - for user_id, device_ids in query_body.get("device_keys", {}).items(): + """ Handle a device key query from a client + + { + "device_keys": { + "": [""] + } + } + -> + { + "device_keys": { + "": { + "": { + ... + } + } + } + } + """ + device_keys_query = query_body.get("device_keys", {}) + + # separate users by domain. + # make a map from domain to user_id to device_ids + queries_by_domain = collections.defaultdict(dict) + for user_id, device_ids in device_keys_query.items(): user = synapse.types.UserID.from_string(user_id) - if self.is_mine(user): - if not device_ids: - local_query.append((user_id, None)) - else: - for device_id in device_ids: - local_query.append((user_id, device_id)) + queries_by_domain[user.domain][user_id] = device_ids + + # do the queries + # TODO: do these in parallel + results = {} + for destination, destination_query in queries_by_domain.items(): + if destination == self.hs.hostname: + res = yield self.query_local_devices(destination_query) else: - remote_queries.setdefault(user.domain, {})[user_id] = list( - device_ids + res = yield self.federation.query_client_keys( + destination, {"device_keys": destination_query} ) + res = res["device_keys"] + for user_id, keys in res.items(): + if user_id in destination_query: + results[user_id] = keys + + defer.returnValue((200, {"device_keys": results})) + + @defer.inlineCallbacks + def query_local_devices(self, query): + """Get E2E device keys for local users + + Args: + query (dict[string, list[string]|None): map from user_id to a list + of devices to query (None for all devices) + + Returns: + defer.Deferred: (resolves to dict[string, dict[string, dict]]): + map from user_id -> device_id -> device details + """ + local_query = [] + + for user_id, device_ids in query.items(): + if not self.is_mine_id(user_id): + logger.warning("Request for keys for non-local user %s", + user_id) + raise errors.SynapseError(400, "Not a user here") + + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + results = yield self.store.get_e2e_device_keys(local_query) - json_result = {} + # un-jsonify the results + json_result = collections.defaultdict(dict) for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[ - device_id] = json.loads( - json_bytes - ) + json_result[user_id][device_id] = json.loads(json_bytes) - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( - destination, {"device_keys": device_keys} - ) - for user_id, keys in remote_result["device_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - defer.returnValue((200, {"device_keys": json_result})) + defer.returnValue(json_result) + + @defer.inlineCallbacks + def on_federation_query_client_keys(self, query_body): + """ Handle a device key query from a federated server + """ + device_keys_query = query_body.get("device_keys", {}) + res = yield self.query_local_devices(device_keys_query) + defer.returnValue({"device_keys": res}) -- cgit 1.5.1 From 4fec5e57be72e5374342637b4062aeff0df6adc3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 11:39:39 +0100 Subject: Default device_display_name to null It turns out that it's more useful to return a null device display name (and let clients decide how to handle it: eg, falling back to device_id) than using a constant string like "unknown device". --- synapse/handlers/device.py | 2 +- synapse/rest/client/v2_alpha/keys.py | 4 +--- .../storage/schema/delta/33/devices_for_e2e_keys.sql | 2 +- .../33/devices_for_e2e_keys_clear_unknown_device.sql | 20 ++++++++++++++++++++ 4 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f4bf159bb5..fcbe7f8e6b 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,7 +29,7 @@ class DeviceHandler(BaseHandler): @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, - initial_device_display_name): + initial_device_display_name = None): """ If the given device has not been registered, register it with the supplied display name. diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index dc1d4d8fc6..5fa33aceea 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -130,9 +130,7 @@ class KeyUploadServlet(RestServlet): # old access_token without an associated device_id. Either way, we # need to double-check the device is registered to avoid ending up with # keys without a corresponding device. - self.device_handler.check_device_registered( - user_id, device_id, "unknown device" - ) + self.device_handler.check_device_registered(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue((200, {"one_time_key_counts": result})) diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql index 140f2b63e0..aa4a3b9f2f 100644 --- a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql +++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql @@ -16,4 +16,4 @@ -- make sure that we have a device record for each set of E2E keys, so that the -- user can delete them if they like. INSERT INTO devices - SELECT user_id, device_id, 'unknown device' FROM e2e_device_keys_json; + SELECT user_id, device_id, NULL FROM e2e_device_keys_json; diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql new file mode 100644 index 0000000000..6671573398 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- a previous version of the "devices_for_e2e_keys" delta set all the device +-- names to "unknown device". This wasn't terribly helpful +UPDATE devices + SET display_name = NULL + WHERE display_name = 'unknown device'; -- cgit 1.5.1 From a843868fe942aaa9d32fe858476d1b459813a854 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 14:24:33 +0100 Subject: E2eKeysHandler: minor tweaks PR feedback --- synapse/handlers/e2e_keys.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9c7e9494d6..1312cdf5ab 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -22,17 +22,15 @@ from twisted.internet import defer from synapse.api import errors import synapse.types -from ._base import BaseHandler - logger = logging.getLogger(__name__) -class E2eKeysHandler(BaseHandler): +class E2eKeysHandler(object): def __init__(self, hs): - super(E2eKeysHandler, self).__init__(hs) self.store = hs.get_datastore() self.federation = hs.get_replication_layer() self.is_mine_id = hs.is_mine_id + self.server_name = hs.hostname # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the @@ -74,7 +72,7 @@ class E2eKeysHandler(BaseHandler): # TODO: do these in parallel results = {} for destination, destination_query in queries_by_domain.items(): - if destination == self.hs.hostname: + if destination == self.server_name: res = yield self.query_local_devices(destination_query) else: res = yield self.federation.query_client_keys( -- cgit 1.5.1 From a6f5cc65d9c4b1b1adf909355c03e72c456627a6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 14:30:06 +0100 Subject: PEP8 --- synapse/handlers/device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index fcbe7f8e6b..8d630c6b1a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,7 +29,7 @@ class DeviceHandler(BaseHandler): @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, - initial_device_display_name = None): + initial_device_display_name=None): """ If the given device has not been registered, register it with the supplied display name. -- cgit 1.5.1 From 91fa69e0299167dad4df9831b8d175a99564b266 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 11:48:32 +0100 Subject: keys/query: return all users which were asked for In the situation where all of a user's devices get deleted, we want to indicate this to a client, so we want to return an empty dictionary, rather than nothing at all. --- synapse/handlers/e2e_keys.py | 9 +++++--- tests/handlers/test_e2e_keys.py | 46 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 tests/handlers/test_e2e_keys.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 1312cdf5ab..950fc927b1 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,6 +99,7 @@ class E2eKeysHandler(object): """ local_query = [] + result_dict = {} for user_id, device_ids in query.items(): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", @@ -111,15 +112,17 @@ class E2eKeysHandler(object): for device_id in device_ids: local_query.append((user_id, device_id)) + # make sure that each queried user appears in the result dict + result_dict[user_id] = {} + results = yield self.store.get_e2e_device_keys(local_query) # un-jsonify the results - json_result = collections.defaultdict(dict) for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): - json_result[user_id][device_id] = json.loads(json_bytes) + result_dict[user_id][device_id] = json.loads(json_bytes) - defer.returnValue(json_result) + defer.returnValue(result_dict) @defer.inlineCallbacks def on_federation_query_client_keys(self, query_body): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py new file mode 100644 index 0000000000..878a54dc34 --- /dev/null +++ b/tests/handlers/test_e2e_keys.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +from twisted.internet import defer + +import synapse.api.errors +import synapse.handlers.e2e_keys + +import synapse.storage +from tests import unittest, utils + + +class E2eKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield utils.setup_test_homeserver( + handlers=None, + replication_layer=mock.Mock(), + ) + self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) + + @defer.inlineCallbacks + def test_query_local_devices_no_devices(self): + """If the user has no devices, we expect an empty list. + """ + local_user = "@boris:" + self.hs.hostname + res = yield self.handler.query_local_devices({local_user: None}) + self.assertDictEqual(res, {local_user: {}}) -- cgit 1.5.1 From 68264d7404214d30d32310991c89ea4a234c319a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 07:46:57 +0100 Subject: Include device name in /keys/query response Add an 'unsigned' section which includes the device display name. --- synapse/handlers/e2e_keys.py | 11 +++-- synapse/storage/end_to_end_keys.py | 60 ++++++++++++++++------- tests/storage/test_end_to_end_keys.py | 92 +++++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 20 deletions(-) create mode 100644 tests/storage/test_end_to_end_keys.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 950fc927b1..bb69089b91 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -117,10 +117,15 @@ class E2eKeysHandler(object): results = yield self.store.get_e2e_device_keys(local_query) - # un-jsonify the results + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - result_dict[user_id][device_id] = json.loads(json_bytes) + for device_id, device_info in device_keys.items(): + r = json.loads(device_info["key_json"]) + r["unsigned"] = { + "device_display_name": device_info["device_display_name"], + } + result_dict[user_id][device_id] = r defer.returnValue(result_dict) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 62b7790e91..5c8ed3e492 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import twisted.internet.defer @@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore): query_list(list): List of pairs of user_ids and device_ids. Returns: Dict mapping from user-id to dict mapping from device_id to - key json byte strings. + dict containing "key_json", "device_display_name". """ - def _get_e2e_device_keys(txn): - result = {} - for user_id, device_id in query_list: - user_result = result.setdefault(user_id, {}) - keyvalues = {"user_id": user_id} - if device_id: - keyvalues["device_id"] = device_id - rows = self._simple_select_list_txn( - txn, table="e2e_device_keys_json", - keyvalues=keyvalues, - retcols=["device_id", "key_json"] - ) - for row in rows: - user_result[row["device_id"]] = row["key_json"] - return result - return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + if not query_list: + return {} + + return self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + ) + + def _get_e2e_device_keys_txn(self, txn, query_list): + query_clauses = [] + query_params = [] + + for (user_id, device_id) in query_list: + query_clause = "k.user_id = ?" + query_params.append(user_id) + + if device_id: + query_clause += " AND k.device_id = ?" + query_params.append(device_id) + + query_clauses.append(query_clause) + + sql = ( + "SELECT k.user_id, k.device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM e2e_device_keys_json k" + " LEFT JOIN devices d ON d.user_id = k.user_id" + " AND d.device_id = k.device_id" + " WHERE %s" + ) % ( + " OR ".join("("+q+")" for q in query_clauses) + ) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + result = collections.defaultdict(dict) + for row in rows: + result[row["user_id"]][row["device_id"]] = row + + return result def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def _add_e2e_one_time_keys(txn): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py new file mode 100644 index 0000000000..0ebc6dafe8 --- /dev/null +++ b/tests/storage/test_end_to_end_keys.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.api.errors +import tests.unittest +import tests.utils + + +class EndToEndKeyStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_key_without_device_name(self): + now = 1470174257070 + json = '{ "key": "value" }' + + yield self.store.set_e2e_device_keys( + "user", "device", now, json) + + res = yield self.store.get_e2e_device_keys((("user", "device"),)) + self.assertIn("user", res) + self.assertIn("device", res["user"]) + dev = res["user"]["device"] + self.assertDictContainsSubset({ + "key_json": json, + "device_display_name": None, + }, dev) + + @defer.inlineCallbacks + def test_get_key_with_device_name(self): + now = 1470174257070 + json = '{ "key": "value" }' + + yield self.store.set_e2e_device_keys( + "user", "device", now, json) + yield self.store.store_device( + "user", "device", "display_name" + ) + + res = yield self.store.get_e2e_device_keys((("user", "device"),)) + self.assertIn("user", res) + self.assertIn("device", res["user"]) + dev = res["user"]["device"] + self.assertDictContainsSubset({ + "key_json": json, + "device_display_name": "display_name", + }, dev) + + + @defer.inlineCallbacks + def test_multiple_devices(self): + now = 1470174257070 + + yield self.store.set_e2e_device_keys( + "user1", "device1", now, 'json11') + yield self.store.set_e2e_device_keys( + "user1", "device2", now, 'json12') + yield self.store.set_e2e_device_keys( + "user2", "device1", now, 'json21') + yield self.store.set_e2e_device_keys( + "user2", "device2", now, 'json22') + + res = yield self.store.get_e2e_device_keys((("user1", "device1"), + ("user2", "device2"))) + self.assertIn("user1", res) + self.assertIn("device1", res["user1"]) + self.assertNotIn("device2", res["user1"]) + self.assertIn("user2", res) + self.assertNotIn("device1", res["user2"]) + self.assertIn("device2", res["user2"]) -- cgit 1.5.1 From f131cd9e53b8131c5f7fa71765c717eadd001f16 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 4 Aug 2016 10:59:51 +0100 Subject: keys/query: Omit device displayname if null ... which makes it more consistent with user displaynames. --- synapse/handlers/e2e_keys.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bb69089b91..2c7bfd91ed 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -122,9 +122,10 @@ class E2eKeysHandler(object): for user_id, device_keys in results.items(): for device_id, device_info in device_keys.items(): r = json.loads(device_info["key_json"]) - r["unsigned"] = { - "device_display_name": device_info["device_display_name"], - } + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name result_dict[user_id][device_id] = r defer.returnValue(result_dict) -- cgit 1.5.1 From 93acf49e9b6d601c65ccb73b46ffe24c9b0f26a2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 5 Aug 2016 12:59:04 +0100 Subject: Fix backfill auth events --- synapse/handlers/federation.py | 71 +++++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 21 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 187bfc4315..618cb53629 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -335,31 +335,58 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - seen_events = yield self.store.have_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - all_events = events + state_events.values() + auth_events.values() required_auth = set( - a_id for event in all_events for a_id, _ in event.auth_events + a_id + for event in events + state_events.values() + auth_events.values() + for a_id, _ in event.auth_events ) - + auth_events.update({ + e_id: event_map[e_id] for e_id in required_auth if e_id in event_map + }) missing_auth = required_auth - set(auth_events) - if missing_auth: + failed_to_fetch = set() + + # Try and fetch any missing auth events from both DB and remote servers. + # We repeatedly do this until we stop finding new auth events. + while missing_auth - failed_to_fetch: logger.info("Missing auth for backfill: %r", missing_auth) - results = yield defer.gatherResults( - [ - self.replication_layer.get_pdu( - [dest], - event_id, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + auth_events.update(ret_events) + + required_auth.update( + a_id for event in ret_events.values() for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + if missing_auth - failed_to_fetch: + logger.info( + "Fetching missing auth for backfill: %r", + missing_auth - failed_to_fetch + ) + + results = yield defer.gatherResults( + [ + self.replication_layer.get_pdu( + [dest], + event_id, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True + ).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results}) + required_auth.update( + a_id for event in results for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + failed_to_fetch = missing_auth - set(auth_events) + + seen_events = yield self.store.have_events( + set(auth_events.keys()) | set(state_events.keys()) + ) ev_infos = [] for a in auth_events.values(): @@ -372,6 +399,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in a.auth_events + if a_id in auth_events } }) @@ -383,6 +411,7 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in event_map[e_id].auth_events + if a_id in auth_events } }) -- cgit 1.5.1 From 6fe6a6f0299c97086a552eda75570eaa66ff2598 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 16:34:07 +0100 Subject: Fix login with m.login.token login with token (as used by CAS auth) was broken by 067596d, such that it always returned a 401. --- synapse/api/auth.py | 45 +++++++++++++++++++++++++------------- synapse/handlers/auth.py | 17 ++++----------- synapse/server.pyi | 4 ++++ tests/handlers/test_auth.py | 53 +++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 87 insertions(+), 32 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 59db76debc..0db26fcfd7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -675,27 +675,18 @@ class Auth(object): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - user_prefix = "user_id = " - user = None - user_id = None - guest = False - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - user_id = caveat.caveat_id[len(user_prefix):] - user = UserID.from_string(user_id) - elif caveat.caveat_id == "guest = true": - guest = True + user_id = self.get_user_id_from_macaroon(macaroon) + user = UserID.from_string(user_id) self.validate_macaroon( macaroon, rights, self.hs.config.expire_access_token, user_id=user_id, ) - if user is None: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) + guest = False + for caveat in macaroon.caveats: + if caveat.caveat_id == "guest = true": + guest = True if guest: ret = { @@ -743,6 +734,29 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) + def get_user_id_from_macaroon(self, macaroon): + """Retrieve the user_id given by the caveats on the macaroon. + + Does *not* validate the macaroon. + + Args: + macaroon (pymacaroons.Macaroon): The macaroon to validate + + Returns: + (str) user id + + Raises: + AuthError if there is no user_id caveat in the macaroon + """ + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + return caveat.caveat_id[len(user_prefix):] + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): """ validate that a Macaroon is understood by and was signed by this server. @@ -754,6 +768,7 @@ class Auth(object): verify_expiry(bool): Whether to verify whether the macaroon has expired. This should really always be True, but no clients currently implement token refresh, so we can't enforce expiry yet. + user_id (str): The user_id required """ v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2e138f328f..1d3641b7a7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -720,10 +720,11 @@ class AuthHandler(BaseHandler): def validate_short_term_login_token_and_get_user_id(self, login_token): try: - macaroon = pymacaroons.Macaroon.deserialize(login_token) auth_api = self.hs.get_auth() - auth_api.validate_macaroon(macaroon, "login", True) - return self.get_user_from_macaroon(macaroon) + macaroon = pymacaroons.Macaroon.deserialize(login_token) + user_id = auth_api.get_user_id_from_macaroon(macaroon) + auth_api.validate_macaroon(macaroon, "login", True, user_id) + return user_id except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) @@ -736,16 +737,6 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def get_user_from_macaroon(self, macaroon): - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] - raise AuthError( - self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", - errcode=Codes.UNKNOWN_TOKEN - ) - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) diff --git a/synapse/server.pyi b/synapse/server.pyi index c0aa868c4f..9570df5537 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,4 @@ +import synapse.api.auth import synapse.handlers import synapse.handlers.auth import synapse.handlers.device @@ -6,6 +7,9 @@ import synapse.storage import synapse.state class HomeServer(object): + def get_auth(self) -> synapse.api.auth.Auth: + pass + def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: pass diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 21077cbe9a..366cf97ae4 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -14,12 +14,13 @@ # limitations under the License. import pymacaroons +from twisted.internet import defer +import synapse +import synapse.api.errors from synapse.handlers.auth import AuthHandler from tests import unittest from tests.utils import setup_test_homeserver -from twisted.internet import defer - class AuthHandlers(object): def __init__(self, hs): @@ -31,11 +32,12 @@ class AuthTestCase(unittest.TestCase): def setUp(self): self.hs = yield setup_test_homeserver(handlers=None) self.hs.handlers = AuthHandlers(self.hs) + self.auth_handler = self.hs.handlers.auth_handler def test_token_is_a_macaroon(self): self.hs.config.macaroon_secret_key = "this key is a huge secret" - token = self.hs.handlers.auth_handler.generate_access_token("some_user") + token = self.auth_handler.generate_access_token("some_user") # Check that we can parse the thing with pymacaroons macaroon = pymacaroons.Macaroon.deserialize(token) # The most basic of sanity checks @@ -46,7 +48,7 @@ class AuthTestCase(unittest.TestCase): self.hs.config.macaroon_secret_key = "this key is a massive secret" self.hs.clock.now = 5000 - token = self.hs.handlers.auth_handler.generate_access_token("a_user") + token = self.auth_handler.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) def verify_gen(caveat): @@ -67,3 +69,46 @@ class AuthTestCase(unittest.TestCase): v.satisfy_general(verify_type) v.satisfy_general(verify_expiry) v.verify(macaroon, self.hs.config.macaroon_secret_key) + + def test_short_term_login_token_gives_user_id(self): + self.hs.clock.now = 1000 + + token = self.auth_handler.generate_short_term_login_token( + "a_user", 5000 + ) + + self.assertEqual( + "a_user", + self.auth_handler.validate_short_term_login_token_and_get_user_id( + token + ) + ) + + # when we advance the clock, the token should be rejected + self.hs.clock.now = 6000 + with self.assertRaises(synapse.api.errors.AuthError): + self.auth_handler.validate_short_term_login_token_and_get_user_id( + token + ) + + def test_short_term_login_token_cannot_replace_user_id(self): + token = self.auth_handler.generate_short_term_login_token( + "a_user", 5000 + ) + macaroon = pymacaroons.Macaroon.deserialize(token) + + self.assertEqual( + "a_user", + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) + ) + + # add another "user_id" caveat, which might allow us to override the + # user_id. + macaroon.add_first_party_caveat("user_id = b_user") + + with self.assertRaises(synapse.api.errors.AuthError): + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) -- cgit 1.5.1 From 79ebfbe7c62400a2b63d67fb65b1abce29d8bf38 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 9 Aug 2016 16:29:28 +0100 Subject: /login: Respond with a 403 when we get an invalid m.login.token --- synapse/handlers/auth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1d3641b7a7..82998a81ce 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -719,14 +719,14 @@ class AuthHandler(BaseHandler): return macaroon.serialize() def validate_short_term_login_token_and_get_user_id(self, login_token): + auth_api = self.hs.get_auth() try: - auth_api = self.hs.get_auth() macaroon = pymacaroons.Macaroon.deserialize(login_token) user_id = auth_api.get_user_id_from_macaroon(macaroon) auth_api.validate_macaroon(macaroon, "login", True, user_id) return user_id - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + except Exception: + raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( -- cgit 1.5.1 From 11fdfaf03b67810f3d289241f772d3177e3c6b7e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 11:55:15 +0100 Subject: Only resign our own events --- synapse/handlers/federation.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 618cb53629..55d11122ba 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1093,16 +1093,17 @@ class FederationHandler(BaseHandler): ) if event: - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + if self.hs.is_mine_id(event.event_id): + # FIXME: This is a temporary work around where we occasionally + # return events slightly differently than when they were + # originally signed + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) if do_auth: in_room = yield self.auth.check_host_in_room( -- cgit 1.5.1 From 7f41bcbeec64da874c10a70f8a0b8e3f9f0626d8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 13:22:20 +0100 Subject: Correctly auth /event/ requests --- synapse/handlers/federation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 55d11122ba..2f8959db92 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -249,7 +249,7 @@ class FederationHandler(BaseHandler): if ev.type != EventTypes.Member: continue try: - domain = UserID.from_string(ev.state_key).domain + domain = get_domain_from_id(ev.state_key) except: continue @@ -1106,13 +1106,14 @@ class FederationHandler(BaseHandler): ) if do_auth: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin + events = yield self._filter_events_for_server( + origin, event.room_id, [event] ) - if not in_room: + if not events: raise AuthError(403, "Host not in room.") + event = events[0] + defer.returnValue(event) else: defer.returnValue(None) -- cgit 1.5.1 From 739ea29d1ecbdf414db1f5062c8a2aeaa519f4ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Aug 2016 13:32:23 +0100 Subject: Also check if server is in the room --- synapse/handlers/federation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2f8959db92..ff6bb475b5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1106,11 +1106,16 @@ class FederationHandler(BaseHandler): ) if do_auth: + in_room = yield self.auth.check_host_in_room( + event.room_id, + origin + ) + if not in_room: + raise AuthError(403, "Host not in room.") + events = yield self._filter_events_for_server( origin, event.room_id, [event] ) - if not events: - raise AuthError(403, "Host not in room.") event = events[0] -- cgit 1.5.1