From 87528f07561d16dbf35aeebdfcecb111ed385b4f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 21 Sep 2016 11:46:28 +0100 Subject: Support /initialSync in synchrotron worker --- synapse/handlers/initial_sync.py | 443 +++++++++++++++++++++++++++++++++++++++ synapse/handlers/message.py | 381 +-------------------------------- 2 files changed, 446 insertions(+), 378 deletions(-) create mode 100644 synapse/handlers/initial_sync.py (limited to 'synapse/handlers') diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py new file mode 100644 index 0000000000..fbfa5a0281 --- /dev/null +++ b/synapse/handlers/initial_sync.py @@ -0,0 +1,443 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError, Codes +from synapse.events.utils import serialize_event +from synapse.events.validator import EventValidator +from synapse.streams.config import PaginationConfig +from synapse.types import ( + UserID, StreamToken, +) +from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute +from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.visibility import filter_events_for_client + +from ._base import BaseHandler + +import logging + + +logger = logging.getLogger(__name__) + + +class InitialSyncHandler(BaseHandler): + def __init__(self, hs): + super(InitialSyncHandler, self).__init__(hs) + self.hs = hs + self.state = hs.get_state_handler() + self.clock = hs.get_clock() + self.validator = EventValidator() + self.snapshot_cache = SnapshotCache() + + def snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + """Retrieve a snapshot of all rooms the user is invited or has joined. + + This snapshot may include messages for all rooms where the user is + joined, depending on the pagination config. + + Args: + user_id (str): The ID of the user making the request. + pagin_config (synapse.api.streams.PaginationConfig): The pagination + config used to determine how many messages *PER ROOM* to return. + as_client_event (bool): True to get events in client-server format. + include_archived (bool): True to get rooms that the user has left + Returns: + A list of dicts with "room_id" and "membership" keys for all rooms + the user is currently invited or joined in on. Rooms where the user + is joined on, may return a "messages" key with messages, depending + on the specified PaginationConfig. + """ + key = ( + user_id, + pagin_config.from_token, + pagin_config.to_token, + pagin_config.direction, + pagin_config.limit, + as_client_event, + include_archived, + ) + now_ms = self.clock.time_msec() + result = self.snapshot_cache.get(now_ms, key) + if result is not None: + return result + + return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + )) + + @defer.inlineCallbacks + def _snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + + memberships = [Membership.INVITE, Membership.JOIN] + if include_archived: + memberships.append(Membership.LEAVE) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, membership_list=memberships + ) + + user = UserID.from_string(user_id) + + rooms_ret = [] + + now_token = yield self.hs.get_event_sources().get_current_token() + + presence_stream = self.hs.get_event_sources().sources["presence"] + pagination_config = PaginationConfig(from_token=now_token) + presence, _ = yield presence_stream.get_pagination_rows( + user, pagination_config.get_source_config("presence"), None + ) + + receipt_stream = self.hs.get_event_sources().sources["receipt"] + receipt, _ = yield receipt_stream.get_pagination_rows( + user, pagination_config.get_source_config("receipt"), None + ) + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user(user_id) + ) + + public_room_ids = yield self.store.get_public_room_ids() + + limit = pagin_config.limit + if limit is None: + limit = 10 + + @defer.inlineCallbacks + def handle_room(event): + d = { + "room_id": event.room_id, + "membership": event.membership, + "visibility": ( + "public" if event.room_id in public_room_ids + else "private" + ), + } + + if event.membership == Membership.INVITE: + time_now = self.clock.time_msec() + d["inviter"] = event.sender + + invite_event = yield self.store.get_event(event.event_id) + d["invite"] = serialize_event(invite_event, time_now, as_client_event) + + rooms_ret.append(d) + + if event.membership not in (Membership.JOIN, Membership.LEAVE): + return + + try: + if event.membership == Membership.JOIN: + room_end_token = now_token.room_key + deferred_room_state = self.state_handler.get_current_state( + event.room_id + ) + elif event.membership == Membership.LEAVE: + room_end_token = "s%d" % (event.stream_ordering,) + deferred_room_state = self.store.get_state_for_events( + [event.event_id], None + ) + deferred_room_state.addCallback( + lambda states: states[event.event_id] + ) + + (messages, token), current_state = yield preserve_context_over_deferred( + defer.gatherResults( + [ + preserve_fn(self.store.get_recent_events_for_room)( + event.room_id, + limit=limit, + end_token=room_end_token, + ), + deferred_room_state, + ] + ) + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + time_now = self.clock.time_msec() + + d["messages"] = { + "chunk": [ + serialize_event(m, time_now, as_client_event) + for m in messages + ], + "start": start_token.to_string(), + "end": end_token.to_string(), + } + + d["state"] = [ + serialize_event(c, time_now, as_client_event) + for c in current_state.values() + ] + + account_data_events = [] + tags = tags_by_room.get(event.room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = account_data_by_room.get(event.room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + d["account_data"] = account_data_events + except: + logger.exception("Failed to get snapshot") + + yield concurrently_execute(handle_room, room_list, 10) + + account_data_events = [] + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + ret = { + "rooms": rooms_ret, + "presence": presence, + "account_data": account_data_events, + "receipts": receipt, + "end": now_token.to_string(), + } + + defer.returnValue(ret) + + @defer.inlineCallbacks + def room_initial_sync(self, requester, room_id, pagin_config=None): + """Capture the a snapshot of a room. If user is currently a member of + the room this will be what is currently in the room. If the user left + the room this will be what was in the room when they left. + + Args: + requester(Requester): The user to get a snapshot for. + room_id(str): The room to get a snapshot of. + pagin_config(synapse.streams.config.PaginationConfig): + The pagination config used to determine how many messages to + return. + Raises: + AuthError if the user wasn't in the room. + Returns: + A JSON serialisable dict with the snapshot of the room. + """ + + user_id = requester.user.to_string() + + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, + ) + is_peeking = member_event_id is None + + if membership == Membership.JOIN: + result = yield self._room_initial_sync_joined( + user_id, room_id, pagin_config, membership, is_peeking + ) + elif membership == Membership.LEAVE: + result = yield self._room_initial_sync_parted( + user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ) + + account_data_events = [] + tags = yield self.store.get_tags_for_room(user_id, room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + result["account_data"] = account_data_events + + defer.returnValue(result) + + @defer.inlineCallbacks + def _room_initial_sync_parted(self, user_id, room_id, pagin_config, + membership, member_event_id, is_peeking): + room_state = yield self.store.get_state_for_events( + [member_event_id], None + ) + + room_state = room_state[member_event_id] + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + stream_token = yield self.store.get_stream_token_for_event( + member_event_id + ) + + messages, token = yield self.store.get_recent_events_for_room( + room_id, + limit=limit, + end_token=stream_token + ) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking + ) + + start_token = StreamToken.START.copy_and_replace("room_key", token[0]) + end_token = StreamToken.START.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + defer.returnValue({ + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": [serialize_event(s, time_now) for s in room_state.values()], + "presence": [], + "receipts": [], + }) + + @defer.inlineCallbacks + def _room_initial_sync_joined(self, user_id, room_id, pagin_config, + membership, is_peeking): + current_state = yield self.state.get_current_state( + room_id=room_id, + ) + + # TODO: These concurrently + time_now = self.clock.time_msec() + state = [ + serialize_event(x, time_now) + for x in current_state.values() + ] + + now_token = yield self.hs.get_event_sources().get_current_token() + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + room_members = [ + m for m in current_state.values() + if m.type == EventTypes.Member + and m.content["membership"] == Membership.JOIN + ] + + presence_handler = self.hs.get_presence_handler() + + @defer.inlineCallbacks + def get_presence(): + states = yield presence_handler.get_states( + [m.user_id for m in room_members], + as_event=True, + ) + + defer.returnValue(states) + + @defer.inlineCallbacks + def get_receipts(): + receipts_handler = self.hs.get_handlers().receipts_handler + receipts = yield receipts_handler.get_receipts_for_room( + room_id, + now_token.receipt_key + ) + defer.returnValue(receipts) + + presence, receipts, (messages, token) = yield defer.gatherResults( + [ + preserve_fn(get_presence)(), + preserve_fn(get_receipts)(), + preserve_fn(self.store.get_recent_events_for_room)( + room_id, + limit=limit, + end_token=now_token.room_key, + ) + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking, + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + ret = { + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": state, + "presence": presence, + "receipts": receipts, + } + if not is_peeking: + ret["membership"] = membership + + defer.returnValue(ret) + + @defer.inlineCallbacks + def _check_in_room_or_world_readable(self, room_id, user_id): + try: + # check_user_was_in_room will return the most recent membership + # event for the user if: + # * The user is a non-guest user, and was ever in the room + # * The user is a guest user, and has joined the room + # else it will throw. + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + return + except AuthError: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if ( + visibility and + visibility.content["history_visibility"] == "world_readable" + ): + defer.returnValue((Membership.JOIN, None)) + return + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 178209a209..30ea9630f7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,14 +21,11 @@ from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.push.action_generator import ActionGenerator -from synapse.streams.config import PaginationConfig from synapse.types import ( - UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id + UserID, RoomAlias, RoomStreamToken, get_domain_from_id ) -from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock -from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.async import run_on_reactor, ReadWriteLock +from synapse.util.logcontext import preserve_fn from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -49,7 +46,6 @@ class MessageHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() self.pagination_lock = ReadWriteLock() @@ -392,377 +388,6 @@ class MessageHandler(BaseHandler): [serialize_event(c, now) for c in room_state.values()] ) - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - """Retrieve a snapshot of all rooms the user is invited or has joined. - - This snapshot may include messages for all rooms where the user is - joined, depending on the pagination config. - - Args: - user_id (str): The ID of the user making the request. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config used to determine how many messages *PER ROOM* to return. - as_client_event (bool): True to get events in client-server format. - include_archived (bool): True to get rooms that the user has left - Returns: - A list of dicts with "room_id" and "membership" keys for all rooms - the user is currently invited or joined in on. Rooms where the user - is joined on, may return a "messages" key with messages, depending - on the specified PaginationConfig. - """ - key = ( - user_id, - pagin_config.from_token, - pagin_config.to_token, - pagin_config.direction, - pagin_config.limit, - as_client_event, - include_archived, - ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) - - @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - - memberships = [Membership.INVITE, Membership.JOIN] - if include_archived: - memberships.append(Membership.LEAVE) - - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, membership_list=memberships - ) - - user = UserID.from_string(user_id) - - rooms_ret = [] - - now_token = yield self.hs.get_event_sources().get_current_token() - - presence_stream = self.hs.get_event_sources().sources["presence"] - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user, pagination_config.get_source_config("presence"), None - ) - - receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( - user, pagination_config.get_source_config("receipt"), None - ) - - tags_by_room = yield self.store.get_tags_for_user(user_id) - - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) - ) - - public_room_ids = yield self.store.get_public_room_ids() - - limit = pagin_config.limit - if limit is None: - limit = 10 - - @defer.inlineCallbacks - def handle_room(event): - d = { - "room_id": event.room_id, - "membership": event.membership, - "visibility": ( - "public" if event.room_id in public_room_ids - else "private" - ), - } - - if event.membership == Membership.INVITE: - time_now = self.clock.time_msec() - d["inviter"] = event.sender - - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = serialize_event(invite_event, time_now, as_client_event) - - rooms_ret.append(d) - - if event.membership not in (Membership.JOIN, Membership.LEAVE): - return - - try: - if event.membership == Membership.JOIN: - room_end_token = now_token.room_key - deferred_room_state = self.state_handler.get_current_state( - event.room_id - ) - elif event.membership == Membership.LEAVE: - room_end_token = "s%d" % (event.stream_ordering,) - deferred_room_state = self.store.get_state_for_events( - [event.event_id], None - ) - deferred_room_state.addCallback( - lambda states: states[event.event_id] - ) - - (messages, token), current_state = yield preserve_context_over_deferred( - defer.gatherResults( - [ - preserve_fn(self.store.get_recent_events_for_room)( - event.room_id, - limit=limit, - end_token=room_end_token, - ), - deferred_room_state, - ] - ) - ).addErrback(unwrapFirstError) - - messages = yield filter_events_for_client( - self.store, user_id, messages - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - time_now = self.clock.time_msec() - - d["messages"] = { - "chunk": [ - serialize_event(m, time_now, as_client_event) - for m in messages - ], - "start": start_token.to_string(), - "end": end_token.to_string(), - } - - d["state"] = [ - serialize_event(c, time_now, as_client_event) - for c in current_state.values() - ] - - account_data_events = [] - tags = tags_by_room.get(event.room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(event.room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - d["account_data"] = account_data_events - except: - logger.exception("Failed to get snapshot") - - yield concurrently_execute(handle_room, room_list, 10) - - account_data_events = [] - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - ret = { - "rooms": rooms_ret, - "presence": presence, - "account_data": account_data_events, - "receipts": receipt, - "end": now_token.to_string(), - } - - defer.returnValue(ret) - - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): - """Capture the a snapshot of a room. If user is currently a member of - the room this will be what is currently in the room. If the user left - the room this will be what was in the room when they left. - - Args: - requester(Requester): The user to get a snapshot for. - room_id(str): The room to get a snapshot of. - pagin_config(synapse.streams.config.PaginationConfig): - The pagination config used to determine how many messages to - return. - Raises: - AuthError if the user wasn't in the room. - Returns: - A JSON serialisable dict with the snapshot of the room. - """ - - user_id = requester.user.to_string() - - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, - ) - is_peeking = member_event_id is None - - if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( - user_id, room_id, pagin_config, membership, is_peeking - ) - elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( - user_id, room_id, pagin_config, membership, member_event_id, is_peeking - ) - - account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = yield self.store.get_account_data_for_room(user_id, room_id) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - result["account_data"] = account_data_events - - defer.returnValue(result) - - @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], None - ) - - room_state = room_state[member_event_id] - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - - stream_token = yield self.store.get_stream_token_for_event( - member_event_id - ) - - messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token - ) - - messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking - ) - - start_token = StreamToken.START.copy_and_replace("room_key", token[0]) - end_token = StreamToken.START.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": [serialize_event(s, time_now) for s in room_state.values()], - "presence": [], - "receipts": [], - }) - - @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) - - # TODO: These concurrently - time_now = self.clock.time_msec() - state = [ - serialize_event(x, time_now) - for x in current_state.values() - ] - - now_token = yield self.hs.get_event_sources().get_current_token() - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - - room_members = [ - m for m in current_state.values() - if m.type == EventTypes.Member - and m.content["membership"] == Membership.JOIN - ] - - presence_handler = self.hs.get_presence_handler() - - @defer.inlineCallbacks - def get_presence(): - states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, - ) - - defer.returnValue(states) - - @defer.inlineCallbacks - def get_receipts(): - receipts_handler = self.hs.get_handlers().receipts_handler - receipts = yield receipts_handler.get_receipts_for_room( - room_id, - now_token.receipt_key - ) - defer.returnValue(receipts) - - presence, receipts, (messages, token) = yield defer.gatherResults( - [ - preserve_fn(get_presence)(), - preserve_fn(get_receipts)(), - preserve_fn(self.store.get_recent_events_for_room)( - room_id, - limit=limit, - end_token=now_token.room_key, - ) - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking, - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - ret = { - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": state, - "presence": presence, - "receipts": receipts, - } - if not is_peeking: - ret["membership"] = membership - - defer.returnValue(ret) - @measure_func("_create_new_client_event") @defer.inlineCallbacks def _create_new_client_event(self, builder, prev_event_ids=None): -- cgit 1.4.1 From 90c070c8503a380367f02f98e56b68fe07405413 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 21 Sep 2016 13:17:08 +0100 Subject: Add total_room_count_estimate to /publicRooms --- synapse/handlers/room_list.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 5a533682c5..b04aea0110 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -125,6 +125,8 @@ class RoomListHandler(BaseHandler): if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 ] + total_room_count = len(rooms_to_scan) + if since_token: # Filter out rooms we've already returned previously # `since_token.current_limit` is the index of the last room we @@ -188,6 +190,7 @@ class RoomListHandler(BaseHandler): results = { "chunk": chunk, + "total_room_count_estimate": total_room_count, } if since_token: -- cgit 1.4.1 From 1168cbd54db059cefdf968077f6b14163de6c04c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Sep 2016 10:56:53 +0100 Subject: Allow invites via 3pid to bypass sender sig check When a server sends a third party invite another server may be the one that the inviting user registers with. In this case it is that remote server that will issue an actual invitation, and wants to do it "in the name of" the original invitee. However, the new proper invite will not be signed by the original server, and thus other servers would reject the invite if it was seen as coming from the original user. To fix this, a special case has been added to the auth rules whereby another server can send an invite "in the name of" another server's user, so long as that user had previously issued a third party invite that is now being accepted. --- synapse/api/auth.py | 17 ++++++++++++++++- synapse/handlers/federation.py | 12 ++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 98a50f0948..d60c1b15ae 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -72,7 +72,7 @@ class Auth(object): auth_events = { (e.type, e.state_key): e for e in auth_events.values() } - self.check(event, auth_events=auth_events, do_sig_check=False) + self.check(event, auth_events=auth_events, do_sig_check=do_sig_check) def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. @@ -92,9 +92,21 @@ class Auth(object): raise AuthError(500, "Event has no room_id: %s" % event) sender_domain = get_domain_from_id(event.sender) + event_id_domain = get_domain_from_id(event.event_id) + + is_invite_via_3pid = ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) # Check the sender's domain has signed the event if do_sig_check and not event.signatures.get(sender_domain): + if not is_invite_via_3pid: + raise AuthError(403, "Event not signed by sender's server") + + # Check the event_id's domain has signed the event + if do_sig_check and not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") if auth_events is None: @@ -491,6 +503,9 @@ class Auth(object): if not invite_event: return False + if invite_event.sender != event.sender: + return False + if event.user_id != invite_event.user_id: return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f7cb3c1bb2..a393263e1e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1922,15 +1922,15 @@ class FederationHandler(BaseHandler): original_invite = yield self.store.get_event( original_invite_id, allow_none=True ) - if not original_invite: + if original_invite: + display_name = original_invite.content["display_name"] + event_dict["content"]["third_party_invite"]["display_name"] = display_name + else: logger.info( - "Could not find invite event for third_party_invite - " - "discarding: %s" % (event_dict,) + "Could not find invite event for third_party_invite: %r", + event_dict ) - return - display_name = original_invite.content["display_name"] - event_dict["content"]["third_party_invite"]["display_name"] = display_name builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) message_handler = self.hs.get_handlers().message_handler -- cgit 1.4.1 From 2e9ee3096907573773d3f0e4ff22dd014b8253c8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 22 Sep 2016 11:59:46 +0100 Subject: Add comments --- synapse/api/auth.py | 3 +++ synapse/handlers/federation.py | 3 +++ 2 files changed, 6 insertions(+) (limited to 'synapse/handlers') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 377bfcc482..5bd250992a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -103,6 +103,9 @@ class Auth(object): # Check the sender's domain has signed the event if not event.signatures.get(sender_domain): + # We allow invites via 3pid to have a sender from a differnt + # HS, as the sender must match the sender of the original + # 3pid invite. This is checked further down. if not is_invite_via_3pid: raise AuthError(403, "Event not signed by sender's server") diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a393263e1e..2d801bad47 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1930,6 +1930,9 @@ class FederationHandler(BaseHandler): "Could not find invite event for third_party_invite: %r", event_dict ) + # We don't discard here as this is not the appropriate place to do + # auth checks. If we need the invite and don't have it then the + # auth check code will explode appropriately. builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) -- cgit 1.4.1 From 22578545a05944284eab3ba7646e2c1c5c36e359 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Sep 2016 13:56:14 +0100 Subject: Time out typing over federation --- synapse/federation/federation_client.py | 2 - synapse/handlers/typing.py | 175 +++++++++++++++++++------------- synapse/rest/client/v1/room.py | 5 +- tests/handlers/test_typing.py | 7 +- tests/rest/client/v1/test_typing.py | 5 +- tests/utils.py | 9 +- 6 files changed, 120 insertions(+), 83 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 06d0320b1a..94e76b1978 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -136,9 +136,7 @@ class FederationClient(FederationBase): sent_edus_counter.inc() - # TODO, add errback, etc. self._transaction_queue.enqueue_edu(edu, key=key) - return defer.succeed(None) @log_function def send_device_messages(self, destination): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0548b81c34..505a68d142 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -16,10 +16,9 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, -) +from synapse.util.logcontext import preserve_fn from synapse.util.metrics import Measure +from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID, get_domain_from_id import logging @@ -35,6 +34,13 @@ logger = logging.getLogger(__name__) RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 60 * 1000 + +# How often to resend typing across federation. +FEDERATION_PING_INTERVAL = 40 * 1000 + + class TypingHandler(object): def __init__(self, hs): self.store = hs.get_datastore() @@ -44,7 +50,10 @@ class TypingHandler(object): self.notifier = hs.get_notifier() self.state = hs.get_state_handler() + self.hs = hs + self.clock = hs.get_clock() + self.wheel_timer = WheelTimer() self.federation = hs.get_replication_layer() @@ -53,7 +62,7 @@ class TypingHandler(object): hs.get_distributor().observe("user_left_room", self.user_left_room) self._member_typing_until = {} # clock time we expect to stop - self._member_typing_timer = {} # deferreds to manage theabove + self._member_last_federation_poke = {} # map room IDs to serial numbers self._room_serials = {} @@ -61,12 +70,41 @@ class TypingHandler(object): # map room IDs to sets of users currently typing self._room_typing = {} - def tearDown(self): - """Cancels all the pending timers. - Normally this shouldn't be needed, but it's required from unit tests - to avoid a "Reactor was unclean" warning.""" - for t in self._member_typing_timer.values(): - self.clock.cancel_call_later(t) + self.clock.looping_call( + self._handle_timeouts, + 5000, + ) + + def _handle_timeouts(self): + logger.info("Handling typing timeout") + + now = self.clock.time_msec() + + members = set(self.wheel_timer.fetch(now)) + + for member in members: + if not self.is_typing(member): + # Nothing to do if they're no longer typing + continue + + until = self._member_typing_until.get(member, None) + if not until or until < now: + logger.info("Timing out typing for: %s", member.user_id) + preserve_fn(self._stopped_typing)(member) + continue + + # Check if we need to resend a keep alive over federation for this + # user. + if self.hs.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now: + preserve_fn(self._push_remote)( + member=member, + typing=True + ) + + def is_typing(self, member): + return member.user_id in self._room_typing.get(member.room_id, []) @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): @@ -85,23 +123,23 @@ class TypingHandler(object): "%s has started typing in %s", target_user_id, room_id ) - until = self.clock.time_msec() + timeout member = RoomMember(room_id=room_id, user_id=target_user_id) - was_present = member in self._member_typing_until + was_present = member.user_id in self._room_typing.get(room_id, set()) - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) + now = self.clock.time_msec() + self._member_typing_until[member] = now + timeout - def _cb(): - logger.debug( - "%s has timed out in %s", target_user.to_string(), room_id - ) - self._stopped_typing(member) + self.wheel_timer.insert( + now=now, + obj=member, + then=now + timeout, + ) - self._member_typing_until[member] = until - self._member_typing_timer[member] = self.clock.call_later( - timeout / 1000.0, _cb + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, ) if was_present: @@ -109,8 +147,7 @@ class TypingHandler(object): defer.returnValue(None) yield self._push_update( - room_id=room_id, - user_id=target_user_id, + member=member, typing=True, ) @@ -133,10 +170,6 @@ class TypingHandler(object): member = RoomMember(room_id=room_id, user_id=target_user_id) - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) - del self._member_typing_timer[member] - yield self._stopped_typing(member) @defer.inlineCallbacks @@ -148,57 +181,53 @@ class TypingHandler(object): @defer.inlineCallbacks def _stopped_typing(self, member): - if member not in self._member_typing_until: + if member.user_id not in self._room_typing.get(member.room_id, set()): # No point defer.returnValue(None) + self._member_typing_until.pop(member, None) + self._member_last_federation_poke.pop(member, None) + yield self._push_update( - room_id=member.room_id, - user_id=member.user_id, + member=member, typing=False, ) - del self._member_typing_until[member] - - if member in self._member_typing_timer: - # Don't cancel it - either it already expired, or the real - # stopped_typing() will cancel it - del self._member_typing_timer[member] - @defer.inlineCallbacks - def _push_update(self, room_id, user_id, typing): - users = yield self.state.get_current_user_in_room(room_id) - domains = set(get_domain_from_id(u) for u in users) + def _push_update(self, member, typing): + if self.hs.is_mine_id(member.user_id): + # Only send updates for changes to our own users. + yield self._push_remote(member, typing) + + self._push_update_local( + member=member, + typing=typing + ) - deferreds = [] - for domain in domains: - if domain == self.server_name: - preserve_fn(self._push_update_local)( - room_id=room_id, - user_id=user_id, - typing=typing - ) - else: - deferreds.append(preserve_fn(self.federation.send_edu)( + @defer.inlineCallbacks + def _push_remote(self, member, typing): + users = yield self.state.get_current_user_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + for domain in set(get_domain_from_id(u) for u in users): + if domain != self.server_name: + self.federation.send_edu( destination=domain, edu_type="m.typing", content={ - "room_id": room_id, - "user_id": user_id, + "room_id": member.room_id, + "user_id": member.user_id, "typing": typing, }, - key=(room_id, user_id), - )) - - yield preserve_context_over_deferred( - defer.DeferredList(deferreds, consumeErrors=True) - ) + key=member, + ) @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] user_id = content["user_id"] + member = RoomMember(user_id=user_id, room_id=room_id) + # Check that the string is a valid user id user = UserID.from_string(user_id) @@ -213,26 +242,32 @@ class TypingHandler(object): domains = set(get_domain_from_id(u) for u in users) if self.server_name in domains: + logger.info("Got typing update from %s: %r", user_id, content) + now = self.clock.time_msec() + self._member_typing_until[member] = now + FEDERATION_TIMEOUT + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_TIMEOUT, + ) self._push_update_local( - room_id=room_id, - user_id=user_id, + member=member, typing=content["typing"] ) - def _push_update_local(self, room_id, user_id, typing): - room_set = self._room_typing.setdefault(room_id, set()) + def _push_update_local(self, member, typing): + room_set = self._room_typing.setdefault(member.room_id, set()) if typing: - room_set.add(user_id) + room_set.add(member.user_id) else: - room_set.discard(user_id) + room_set.discard(member.user_id) self._latest_room_serial += 1 - self._room_serials[room_id] = self._latest_room_serial + self._room_serials[member.room_id] = self._latest_room_serial - with PreserveLoggingContext(): - self.notifier.on_new_event( - "typing_key", self._latest_room_serial, rooms=[room_id] - ) + self.notifier.on_new_event( + "typing_key", self._latest_room_serial, rooms=[member.room_id] + ) def get_all_typing_updates(self, last_id, current_id): # TODO: Work out a way to do this without scanning the entire state. diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 20889e4af0..010fbc7c32 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -705,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet): yield self.presence_handler.bump_presence_active_time(requester.user) + # Limit timeout to stop people from setting silly typing timeouts. + timeout = min(content.get("timeout", 30000), 120000) + if content["typing"]: yield self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, - timeout=content.get("timeout", 30000), + timeout=timeout, ) else: yield self.typing_handler.stopped_typing( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ea1f0f7c33..c3108f5181 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase): from synapse.handlers.typing import RoomMember member = RoomMember(self.room_id, self.u_apple.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._member_typing_timer[member] = ( - self.clock.call_later(1002, lambda: 0) - ) - self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),)) + self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()]) self.assertEquals(self.event_source.get_current_key(), 0) @@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase): }, }]) - self.clock.advance_time(11) + self.clock.advance_time(16) self.on_new_event.assert_has_calls([ call('typing_key', 2, rooms=[self.room_id]), diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 467f253ef6..a269e6f56e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase): # Need another user to make notifications actually work yield self.join(self.room_id, user="@jim:red") - def tearDown(self): - self.hs.get_typing_handler().tearDown() - @defer.inlineCallbacks def test_set_typing(self): (code, _) = yield self.mock_resource.trigger( @@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase): self.assertEquals(self.event_source.get_current_key(), 1) - self.clock.advance_time(31) + self.clock.advance_time(36) self.assertEquals(self.event_source.get_current_key(), 2) diff --git a/tests/utils.py b/tests/utils.py index 915b934e94..92d470cb48 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -220,6 +220,7 @@ class MockClock(object): # list of lists of [absolute_time, callback, expired] in no particular # order self.timers = [] + self.loopers = [] def time(self): return self.now @@ -240,7 +241,7 @@ class MockClock(object): return t def looping_call(self, function, interval): - pass + self.loopers.append([function, interval / 1000., self.now]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -269,6 +270,12 @@ class MockClock(object): else: self.timers.append(t) + for looped in self.loopers: + func, interval, last = looped + if last + interval < self.now: + func() + looped[2] = self.now + def advance_time_msec(self, ms): self.advance_time(ms / 1000.) -- cgit 1.4.1 From 655891d179da91206525054eca1aaec562c37e66 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Sep 2016 15:43:34 +0100 Subject: Move FEDERATION_PING_INTERVAL timer. Update log line --- synapse/handlers/typing.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 505a68d142..08313417b2 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -53,7 +53,7 @@ class TypingHandler(object): self.hs = hs self.clock = hs.get_clock() - self.wheel_timer = WheelTimer() + self.wheel_timer = WheelTimer(bucket_size=5000) self.federation = hs.get_replication_layer() @@ -76,7 +76,7 @@ class TypingHandler(object): ) def _handle_timeouts(self): - logger.info("Handling typing timeout") + logger.info("Checking for typing timeouts") now = self.clock.time_msec() @@ -136,12 +136,6 @@ class TypingHandler(object): then=now + timeout, ) - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, - ) - if was_present: # No point sending another notification defer.returnValue(None) @@ -208,6 +202,14 @@ class TypingHandler(object): def _push_remote(self, member, typing): users = yield self.state.get_current_user_in_room(member.room_id) self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, + ) + for domain in set(get_domain_from_id(u) for u in users): if domain != self.server_name: self.federation.send_edu( -- cgit 1.4.1 From 3027ea22b066df6282bc6535319725cbfa2704e6 Mon Sep 17 00:00:00 2001 From: Martin Weinelt Date: Wed, 21 Sep 2016 03:13:34 +0200 Subject: Restructure ldap authentication - properly parse return values of ldap bind() calls - externalize authentication methods - change control flow to be more error-resilient - unbind ldap connections in many places - improve log messages and loglevels --- synapse/handlers/auth.py | 279 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 192 insertions(+), 87 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 6986930c0d..3933ce171a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -31,6 +31,7 @@ import simplejson try: import ldap3 + import ldap3.core.exceptions except ImportError: ldap3 = None pass @@ -504,6 +505,144 @@ class AuthHandler(BaseHandler): raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(user_id) + def _ldap_simple_bind(self, server, localpart, password): + """ Attempt a simple bind with the credentials + given by the user against the LDAP server. + + Returns True, LDAP3Connection + if the bind was successful + Returns False, None + if an error occured + """ + + try: + # bind with the the local users ldap credentials + bind_dn = "{prop}={value},{base}".format( + prop=self.ldap_attributes['uid'], + value=localpart, + base=self.ldap_base + ) + conn = ldap3.Connection(server, bind_dn, password) + logger.debug( + "Established LDAP connection in simple bind mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded LDAP connection in simple bind mode through StartTLS: %s", + conn + ) + + if conn.bind(): + # GOOD: bind okay + logger.debug("LDAP Bind successful in simple bind mode.") + return True, conn + + # BAD: bind failed + logger.info( + "Binding against LDAP failed for '%s' failed: %s", + localpart, conn.result['description'] + ) + conn.unbind() + return False, None + + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during LDAP authentication: %s", e) + return False, None + + def _ldap_authenticated_search(self, server, localpart, password): + """ Attempt to login with the preconfigured bind_dn + and then continue searching and filtering within + the base_dn + + Returns (True, LDAP3Connection) + if a single matching DN within the base was found + that matched the filter expression, and with which + a successful bind was achieved + + The LDAP3Connection returned is the instance that was used to + verify the password not the one using the configured bind_dn. + Returns (False, None) + if an error occured + """ + + try: + conn = ldap3.Connection( + server, + self.ldap_bind_dn, + self.ldap_bind_password + ) + logger.debug( + "Established LDAP connection in search mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded LDAP connection in search mode through StartTLS: %s", + conn + ) + + if not conn.bind(): + logger.warn( + "Binding against LDAP with `bind_dn` failed: %s", + conn.result['description'] + ) + conn.unbind() + return False, None + + # construct search_filter like (uid=localpart) + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + if self.ldap_filter: + # combine with the AND expression + query = "(&{query}{filter})".format( + query=query, + filter=self.ldap_filter + ) + logger.debug( + "LDAP search filter: %s", + query + ) + conn.search( + search_base=self.ldap_base, + search_filter=query + ) + + if len(conn.response) == 1: + # GOOD: found exactly one result + user_dn = conn.response[0]['dn'] + logger.debug('LDAP search found dn: %s', user_dn) + + # unbind and simple bind with user_dn to verify the password + # Note: do not use rebind(), for some reason it did not verify + # the password for me! + conn.unbind() + return self._ldap_simple_bind(server, localpart, password) + else: + # BAD: found 0 or > 1 results, abort! + if len(conn.response) == 0: + logger.info( + "LDAP search returned no results for '%s'", + localpart + ) + else: + logger.info( + "LDAP search returned too many (%s) results for '%s'", + len(conn.response), localpart + ) + conn.unbind() + return False, None + + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during LDAP authentication: %s", e) + return False, None + @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): """ Attempt to authenticate a user against an LDAP Server @@ -516,106 +655,62 @@ class AuthHandler(BaseHandler): if not ldap3 or not self.ldap_enabled: defer.returnValue(False) - if self.ldap_mode not in LDAPMode.LIST: - raise RuntimeError( - 'Invalid ldap mode specified: {mode}'.format( - mode=self.ldap_mode - ) - ) + localpart = UserID.from_string(user_id).localpart try: server = ldap3.Server(self.ldap_uri) logger.debug( - "Attempting ldap connection with %s", + "Attempting LDAP connection with %s", self.ldap_uri ) - localpart = UserID.from_string(user_id).localpart if self.ldap_mode == LDAPMode.SIMPLE: - # bind with the the local users ldap credentials - bind_dn = "{prop}={value},{base}".format( - prop=self.ldap_attributes['uid'], - value=localpart, - base=self.ldap_base + result, conn = self._ldap_simple_bind( + server=server, localpart=localpart, password=password ) - conn = ldap3.Connection(server, bind_dn, password) logger.debug( - "Established ldap connection in simple mode: %s", + 'LDAP authentication method simple bind returned: %s (conn: %s)', + result, conn ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded ldap connection in simple mode through StartTLS: %s", - conn - ) - - conn.bind() - + if not result: + defer.returnValue(False) elif self.ldap_mode == LDAPMode.SEARCH: - # connect with preconfigured credentials and search for local user - conn = ldap3.Connection( - server, - self.ldap_bind_dn, - self.ldap_bind_password + result, conn = self._ldap_authenticated_search( + server=server, localpart=localpart, password=password ) logger.debug( - "Established ldap connection in search mode: %s", + 'LDAP auth method authenticated search returned: %s (conn: %s)', + result, conn ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded ldap connection in search mode through StartTLS: %s", - conn + if not result: + defer.returnValue(False) + else: + raise RuntimeError( + 'Invalid LDAP mode specified: {mode}'.format( + mode=self.ldap_mode ) - - conn.bind() - - # find matching dn - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart ) - if self.ldap_filter: - query = "(&{query}{filter})".format( - query=query, - filter=self.ldap_filter - ) - logger.debug("ldap search filter: %s", query) - result = conn.search(self.ldap_base, query) - - if result and len(conn.response) == 1: - # found exactly one result - user_dn = conn.response[0]['dn'] - logger.debug('ldap search found dn: %s', user_dn) - - # unbind and reconnect, rebind with found dn - conn.unbind() - conn = ldap3.Connection( - server, - user_dn, - password, - auto_bind=True - ) - else: - # found 0 or > 1 results, abort! - logger.warn( - "ldap search returned unexpected (%d!=1) amount of results", - len(conn.response) - ) - defer.returnValue(False) - logger.info( - "User authenticated against ldap server: %s", - conn - ) + try: + logger.info( + "User authenticated against LDAP server: %s", + conn + ) + except NameError: + logger.warn("Authentication method yielded no LDAP connection, aborting!") + defer.returnValue(False) + + # check if user with user_id exists + if (yield self.check_user_exists(user_id)): + # exists, authentication complete + conn.unbind() + defer.returnValue(True) - # check for existing account, if none exists, create one - if not (yield self.check_user_exists(user_id)): - # query user metadata for account creation + else: + # does not exist, fetch metadata for account creation from + # existing ldap connection query = "({prop}={value})".format( prop=self.ldap_attributes['uid'], value=localpart @@ -626,9 +721,12 @@ class AuthHandler(BaseHandler): filter=query, user_filter=self.ldap_filter ) - logger.debug("ldap registration filter: %s", query) + logger.debug( + "ldap registration filter: %s", + query + ) - result = conn.search( + conn.search( search_base=self.ldap_base, search_filter=query, attributes=[ @@ -651,20 +749,27 @@ class AuthHandler(BaseHandler): # TODO: bind email, set displayname with data from ldap directory logger.info( - "ldap registration successful: %d: %s (%s, %)", + "Registration based on LDAP data was successful: %d: %s (%s, %)", user_id, localpart, name, mail ) + + defer.returnValue(True) else: - logger.warn( - "ldap registration failed: unexpected (%d!=1) amount of results", - len(conn.response) - ) + if len(conn.response) == 0: + logger.warn("LDAP registration failed, no result.") + else: + logger.warn( + "LDAP registration failed, too many results (%s)", + len(conn.response) + ) + defer.returnValue(False) - defer.returnValue(True) + defer.returnValue(False) + except ldap3.core.exceptions.LDAPException as e: logger.warn("Error during ldap authentication: %s", e) defer.returnValue(False) -- cgit 1.4.1