From ec609f80949972327274142e4eb5b506f55aba1a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 12 Sep 2016 10:46:02 +0100 Subject: Fix unit tests --- tests/api/test_auth.py | 18 +++++++++--------- tests/handlers/test_typing.py | 3 ++- tests/rest/client/v1/test_register.py | 2 ++ tests/rest/client/v2_alpha/test_register.py | 2 ++ tests/utils.py | 18 ++++++++++++++---- 5 files changed, 29 insertions(+), 14 deletions(-) (limited to 'tests') diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index e91723ca3d..2cf262bb46 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -20,7 +20,7 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID -from tests.utils import setup_test_homeserver +from tests.utils import setup_test_homeserver, mock_getRawHeaders import pymacaroons @@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) @@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) @@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), masquerading_user_id) @@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] - request.requestHeaders.getRawHeaders = Mock(return_value=[""]) + request.requestHeaders.getRawHeaders = mock_getRawHeaders() d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ea1f0f7c33..d71ac4838e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase): "user_id": self.u_onion.to_string(), "typing": True, } - ) + ), + federation_auth=True, ) self.on_new_event.assert_has_calls([ diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 4a898a034f..73822f4d8c 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet from twisted.internet import defer from mock import Mock from tests import unittest +from tests.utils import mock_getRawHeaders import json @@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase): path='/_matrix/client/api/v1/createUser' ) self.request.args = {} + self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.appservice = None self.auth = Mock(get_appservice_by_req=Mock( diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8ac56a1fb2..d7ffd1d427 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError from twisted.internet import defer from mock import Mock from tests import unittest +from tests.utils import mock_getRawHeaders import json @@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase): path='/_matrix/api/v2_alpha/register' ) self.request.args = {} + self.request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.appservice = None self.auth = Mock(get_appservice_by_req=Mock( diff --git a/tests/utils.py b/tests/utils.py index 915b934e94..1b89470c60 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -115,6 +115,15 @@ def get_mock_call_args(pattern_func, mock_func): return getcallargs(pattern_func, *invoked_args, **invoked_kargs) +def mock_getRawHeaders(headers=None): + headers = headers if headers is not None else {} + + def getRawHeaders(name, default=None): + return headers.get(name, default) + + return getRawHeaders + + # This is a mock /resource/ not an entire server class MockHttpResource(HttpServer): @@ -127,7 +136,7 @@ class MockHttpResource(HttpServer): @patch('twisted.web.http.Request') @defer.inlineCallbacks - def trigger(self, http_method, path, content, mock_request): + def trigger(self, http_method, path, content, mock_request, federation_auth=False): """ Fire an HTTP event. Args: @@ -155,9 +164,10 @@ class MockHttpResource(HttpServer): mock_request.getClientIP.return_value = "-" - mock_request.requestHeaders.getRawHeaders.return_value = [ - "X-Matrix origin=test,key=,sig=" - ] + headers = {} + if federation_auth: + headers["Authorization"] = ["X-Matrix origin=test,key=,sig="] + mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it mock_request.path = path -- cgit 1.4.1 From 748d8fdc7bcdb43719e99a48cc74bf078f22396f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 23 Sep 2016 15:31:47 +0100 Subject: Reduce DB hits for replication Some streams will occaisonally advance their positions without actually having any new rows to send over federation. Currently this means that the token will not advance on the workers, leading to them repeatedly sending a slightly out of date token. This in turns requires the master to hit the DB to check if there are any new rows, rather than hitting the no op logic where we check if the given token matches the current token. This commit changes the API to always return an entry if the position for a stream has changed, allowing workers to advance their tokens correctly. --- .gitignore | 8 +- synapse/app/pusher.py | 4 +- synapse/replication/resource.py | 139 +++++++++++++++++++++++-------- synapse/storage/room.py | 3 + tests/replication/slave/storage/_base.py | 3 +- tests/replication/test_resource.py | 3 +- 6 files changed, 115 insertions(+), 45 deletions(-) (limited to 'tests') diff --git a/.gitignore b/.gitignore index f8c4000134..491047c352 100644 --- a/.gitignore +++ b/.gitignore @@ -24,10 +24,10 @@ homeserver*.yaml .coverage htmlcov -demo/*.db -demo/*.log -demo/*.log.* -demo/*.pid +demo/*/*.db +demo/*/*.log +demo/*/*.log.* +demo/*/*.pid demo/media_store.* demo/etc diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index d59f4a571c..1a6f5507a9 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -197,7 +197,7 @@ class PusherServer(HomeServer): yield start_pusher(user_id, app_id, pushkey) stream = results.get("events") - if stream: + if stream and stream["rows"]: min_stream_id = stream["rows"][0][0] max_stream_id = stream["position"] preserve_fn(pusher_pool.on_new_notifications)( @@ -205,7 +205,7 @@ class PusherServer(HomeServer): ) stream = results.get("receipts") - if stream: + if stream and stream["rows"]: rows = stream["rows"] affected_room_ids = set(row[1] for row in rows) min_stream_id = rows[0][0] diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 9aab3ce23c..585bd1c4ad 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -17,6 +17,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.http.server import request_handler, finish_request from synapse.replication.pusher_resource import PusherResource from synapse.replication.presence_resource import PresenceResource +from synapse.api.errors import SynapseError from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -166,7 +167,8 @@ class ReplicationResource(Resource): def replicate(): return self.replicate(request_streams, limit) - result = yield self.notifier.wait_for_replication(replicate, timeout) + writer = yield self.notifier.wait_for_replication(replicate, timeout) + result = writer.finish() for stream_name, stream_content in result.items(): logger.info( @@ -186,6 +188,9 @@ class ReplicationResource(Resource): current_token = yield self.current_replication_token() logger.debug("Replicating up to %r", current_token) + if limit == 0: + raise SynapseError(400, "Limit cannot be 0") + yield self.account_data(writer, current_token, limit, request_streams) yield self.events(writer, current_token, limit, request_streams) # TODO: implement limit @@ -200,7 +205,7 @@ class ReplicationResource(Resource): self.streams(writer, current_token, request_streams) logger.debug("Replicated %d rows", writer.total) - defer.returnValue(writer.finish()) + defer.returnValue(writer) def streams(self, writer, current_token, request_streams): request_token = request_streams.get("streams") @@ -233,31 +238,52 @@ class ReplicationResource(Resource): request_backfill = request_streams.get("backfill") if request_events is not None or request_backfill is not None: - if request_events is None: + if request_backfill is None: request_events = current_token.events if request_backfill is None: request_backfill = current_token.backfill + + no_new_tokens = ( + request_events == current_token.events + and request_backfill == current_token.backfill + ) + if no_new_tokens: + return + res = yield self.store.get_all_new_events( request_backfill, request_events, current_token.backfill, current_token.events, limit ) - writer.write_header_and_rows("events", res.new_forward_events, ( - "position", "internal", "json", "state_group" - )) - writer.write_header_and_rows("backfill", res.new_backfill_events, ( - "position", "internal", "json", "state_group" - )) + + upto_events_token = _position_from_rows( + res.new_forward_events, current_token.events + ) + + upto_backfill_token = _position_from_rows( + res.new_backfill_events, current_token.backfill + ) + + if request_events != upto_events_token: + writer.write_header_and_rows("events", res.new_forward_events, ( + "position", "internal", "json", "state_group" + ), position=upto_events_token) + + if request_backfill != upto_backfill_token: + writer.write_header_and_rows("backfill", res.new_backfill_events, ( + "position", "internal", "json", "state_group", + ), position=upto_backfill_token) + writer.write_header_and_rows( "forward_ex_outliers", res.forward_ex_outliers, - ("position", "event_id", "state_group") + ("position", "event_id", "state_group"), ) writer.write_header_and_rows( "backward_ex_outliers", res.backward_ex_outliers, - ("position", "event_id", "state_group") + ("position", "event_id", "state_group"), ) writer.write_header_and_rows( - "state_resets", res.state_resets, ("position",) + "state_resets", res.state_resets, ("position",), ) @defer.inlineCallbacks @@ -266,15 +292,16 @@ class ReplicationResource(Resource): request_presence = request_streams.get("presence") - if request_presence is not None: + if request_presence is not None and request_presence != current_position: presence_rows = yield self.presence_handler.get_all_presence_updates( request_presence, current_position ) + upto_token = _position_from_rows(presence_rows, current_position) writer.write_header_and_rows("presence", presence_rows, ( "position", "user_id", "state", "last_active_ts", "last_federation_update_ts", "last_user_sync_ts", "status_msg", "currently_active", - )) + ), position=upto_token) @defer.inlineCallbacks def typing(self, writer, current_token, request_streams): @@ -282,7 +309,7 @@ class ReplicationResource(Resource): request_typing = request_streams.get("typing") - if request_typing is not None: + if request_typing is not None and request_typing != current_position: # If they have a higher token than current max, we can assume that # they had been talking to a previous instance of the master. Since # we reset the token on restart, the best (but hacky) thing we can @@ -293,9 +320,10 @@ class ReplicationResource(Resource): typing_rows = yield self.typing_handler.get_all_typing_updates( request_typing, current_position ) + upto_token = _position_from_rows(typing_rows, current_position) writer.write_header_and_rows("typing", typing_rows, ( "position", "room_id", "typing" - )) + ), position=upto_token) @defer.inlineCallbacks def receipts(self, writer, current_token, limit, request_streams): @@ -303,13 +331,14 @@ class ReplicationResource(Resource): request_receipts = request_streams.get("receipts") - if request_receipts is not None: + if request_receipts is not None and request_receipts != current_position: receipts_rows = yield self.store.get_all_updated_receipts( request_receipts, current_position, limit ) + upto_token = _position_from_rows(receipts_rows, current_position) writer.write_header_and_rows("receipts", receipts_rows, ( "position", "room_id", "receipt_type", "user_id", "event_id", "data" - )) + ), position=upto_token) @defer.inlineCallbacks def account_data(self, writer, current_token, limit, request_streams): @@ -324,23 +353,36 @@ class ReplicationResource(Resource): user_account_data = current_position if room_account_data is None: room_account_data = current_position + + no_new_tokens = ( + user_account_data == current_position + and room_account_data == current_position + ) + if no_new_tokens: + return + user_rows, room_rows = yield self.store.get_all_updated_account_data( user_account_data, room_account_data, current_position, limit ) + + upto_users_token = _position_from_rows(user_rows, current_position) + upto_rooms_token = _position_from_rows(room_rows, current_position) + writer.write_header_and_rows("user_account_data", user_rows, ( "position", "user_id", "type", "content" - )) + ), position=upto_users_token) writer.write_header_and_rows("room_account_data", room_rows, ( "position", "user_id", "room_id", "type", "content" - )) + ), position=upto_rooms_token) if tag_account_data is not None: tag_rows = yield self.store.get_all_updated_tags( tag_account_data, current_position, limit ) + upto_tag_token = _position_from_rows(tag_rows, current_position) writer.write_header_and_rows("tag_account_data", tag_rows, ( "position", "user_id", "room_id", "tags" - )) + ), position=upto_tag_token) @defer.inlineCallbacks def push_rules(self, writer, current_token, limit, request_streams): @@ -348,14 +390,15 @@ class ReplicationResource(Resource): push_rules = request_streams.get("push_rules") - if push_rules is not None: + if push_rules is not None and push_rules != current_position: rows = yield self.store.get_all_push_rule_updates( push_rules, current_position, limit ) + upto_token = _position_from_rows(rows, current_position) writer.write_header_and_rows("push_rules", rows, ( "position", "event_stream_ordering", "user_id", "rule_id", "op", "priority_class", "priority", "conditions", "actions" - )) + ), position=upto_token) @defer.inlineCallbacks def pushers(self, writer, current_token, limit, request_streams): @@ -363,18 +406,19 @@ class ReplicationResource(Resource): pushers = request_streams.get("pushers") - if pushers is not None: + if pushers is not None and pushers != current_position: updated, deleted = yield self.store.get_all_updated_pushers( pushers, current_position, limit ) + upto_token = _position_from_rows(updated, current_position) writer.write_header_and_rows("pushers", updated, ( "position", "user_id", "access_token", "profile_tag", "kind", "app_id", "app_display_name", "device_display_name", "pushkey", "ts", "lang", "data" - )) + ), position=upto_token) writer.write_header_and_rows("deleted_pushers", deleted, ( "position", "user_id", "app_id", "pushkey" - )) + ), position=upto_token) @defer.inlineCallbacks def caches(self, writer, current_token, limit, request_streams): @@ -382,13 +426,14 @@ class ReplicationResource(Resource): caches = request_streams.get("caches") - if caches is not None: + if caches is not None and caches != current_position: updated_caches = yield self.store.get_all_updated_caches( caches, current_position, limit ) + upto_token = _position_from_rows(updated_caches, current_position) writer.write_header_and_rows("caches", updated_caches, ( "position", "cache_func", "keys", "invalidation_ts" - )) + ), position=upto_token) @defer.inlineCallbacks def to_device(self, writer, current_token, limit, request_streams): @@ -396,13 +441,14 @@ class ReplicationResource(Resource): to_device = request_streams.get("to_device") - if to_device is not None: + if to_device is not None and to_device != current_position: to_device_rows = yield self.store.get_all_new_device_messages( to_device, current_position, limit ) + upto_token = _position_from_rows(to_device_rows, current_position) writer.write_header_and_rows("to_device", to_device_rows, ( "position", "user_id", "device_id", "message_json" - )) + ), position=upto_token) @defer.inlineCallbacks def public_rooms(self, writer, current_token, limit, request_streams): @@ -410,13 +456,14 @@ class ReplicationResource(Resource): public_rooms = request_streams.get("public_rooms") - if public_rooms is not None: + if public_rooms is not None and public_rooms != current_position: public_rooms_rows = yield self.store.get_all_new_public_rooms( public_rooms, current_position, limit ) + upto_token = _position_from_rows(public_rooms_rows, current_position) writer.write_header_and_rows("public_rooms", public_rooms_rows, ( "position", "room_id", "visibility" - )) + ), position=upto_token) class _Writer(object): @@ -426,11 +473,11 @@ class _Writer(object): self.total = 0 def write_header_and_rows(self, name, rows, fields, position=None): - if not rows: - return - if position is None: - position = rows[-1][0] + if rows: + position = rows[-1][0] + else: + return self.streams[name] = { "position": position if type(position) is int else str(position), @@ -440,6 +487,9 @@ class _Writer(object): self.total += len(rows) + def __nonzero__(self): + return bool(self.total) + def finish(self): return self.streams @@ -461,3 +511,20 @@ class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( def __str__(self): return "_".join(str(value) for value in self) + + +def _position_from_rows(rows, current_position): + """Calculates a position to return for a stream. Ideally we want to return the + position of the last row, as that will be the most correct. However, if there + are no rows we fall back to using the current position to stop us from + repeatedly hitting the storage layer unncessarily thinking there are updates. + (Not all advances of the token correspond to an actual update) + + We can't just always return the current position, as we often limit the + number of rows we replicate, and so the stream may lag. The assumption is + that if the storage layer returns no new rows then we are not lagging and + we are at the `current_position`. + """ + if rows: + return rows[-1][0] + return current_position diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 2ef13d7403..11813b44f6 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -320,6 +320,9 @@ class RoomStore(SQLBaseStore): txn.execute(sql, (prev_id, current_id, limit,)) return txn.fetchall() + if prev_id == current_id: + return defer.succeed([]) + return self.runInteraction( "get_all_new_public_rooms", get_all_new_public_rooms ) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 1f13cd0bc0..b82868054d 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -42,7 +42,8 @@ class BaseSlavedStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def replicate(self): streams = self.slaved_store.stream_positions() - result = yield self.replication.replicate(streams, 100) + writer = yield self.replication.replicate(streams, 100) + result = writer.finish() yield self.slaved_store.process_replication(result) @defer.inlineCallbacks diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index b69832cc1b..f406934a62 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -120,7 +120,7 @@ class ReplicationResourceCase(unittest.TestCase): self.hs.clock.advance_time_msec(1) code, body = yield get self.assertEquals(code, 200) - self.assertEquals(body, {}) + self.assertEquals(body.get("rows", []), []) test_timeout.__name__ = "test_timeout_%s" % (stream) return test_timeout @@ -195,7 +195,6 @@ class ReplicationResourceCase(unittest.TestCase): self.assertIn("field_names", stream) field_names = stream["field_names"] self.assertIn("rows", stream) - self.assertTrue(stream["rows"]) for row in stream["rows"]: self.assertEquals( len(row), len(field_names), -- cgit 1.4.1 From 850b103b36205d2c90da46a0d7413e6033de4f94 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Oct 2016 10:27:10 +0100 Subject: Implement pluggable password auth Allows delegating the password auth to an external module. This also moves the LDAP auth to using this system, allowing it to be removed from the synapse tree entirely in the future. --- synapse/config/homeserver.py | 6 +- synapse/config/ldap.py | 100 -------- synapse/config/password_auth_providers.py | 61 +++++ synapse/handlers/auth.py | 334 ++++----------------------- synapse/util/ldap_auth_provider.py | 368 ++++++++++++++++++++++++++++++ tests/storage/test_appservice.py | 17 +- tests/utils.py | 1 + 7 files changed, 486 insertions(+), 401 deletions(-) delete mode 100644 synapse/config/ldap.py create mode 100644 synapse/config/password_auth_providers.py create mode 100644 synapse/util/ldap_auth_provider.py (limited to 'tests') diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 79b0534b3b..0f890fc04a 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -30,7 +30,7 @@ from .saml2 import SAML2Config from .cas import CasConfig from .password import PasswordConfig from .jwt import JWTConfig -from .ldap import LDAPConfig +from .password_auth_providers import PasswordAuthProviderConfig from .emailconfig import EmailConfig from .workers import WorkerConfig @@ -39,8 +39,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, LDAPConfig, PasswordConfig, EmailConfig, - WorkerConfig,): + JWTConfig, PasswordConfig, EmailConfig, + WorkerConfig, PasswordAuthProviderConfig,): pass diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py deleted file mode 100644 index d83c2230be..0000000000 --- a/synapse/config/ldap.py +++ /dev/null @@ -1,100 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 Niklas Riekenbrauck -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ._base import Config, ConfigError - - -MISSING_LDAP3 = ( - "Missing ldap3 library. This is required for LDAP Authentication." -) - - -class LDAPMode(object): - SIMPLE = "simple", - SEARCH = "search", - - LIST = (SIMPLE, SEARCH) - - -class LDAPConfig(Config): - def read_config(self, config): - ldap_config = config.get("ldap_config", {}) - - self.ldap_enabled = ldap_config.get("enabled", False) - - if self.ldap_enabled: - # verify dependencies are available - try: - import ldap3 - ldap3 # to stop unused lint - except ImportError: - raise ConfigError(MISSING_LDAP3) - - self.ldap_mode = LDAPMode.SIMPLE - - # verify config sanity - self.require_keys(ldap_config, [ - "uri", - "base", - "attributes", - ]) - - self.ldap_uri = ldap_config["uri"] - self.ldap_start_tls = ldap_config.get("start_tls", False) - self.ldap_base = ldap_config["base"] - self.ldap_attributes = ldap_config["attributes"] - - if "bind_dn" in ldap_config: - self.ldap_mode = LDAPMode.SEARCH - self.require_keys(ldap_config, [ - "bind_dn", - "bind_password", - ]) - - self.ldap_bind_dn = ldap_config["bind_dn"] - self.ldap_bind_password = ldap_config["bind_password"] - self.ldap_filter = ldap_config.get("filter", None) - - # verify attribute lookup - self.require_keys(ldap_config['attributes'], [ - "uid", - "name", - "mail", - ]) - - def require_keys(self, config, required): - missing = [key for key in required if key not in config] - if missing: - raise ConfigError( - "LDAP enabled but missing required config values: {}".format( - ", ".join(missing) - ) - ) - - def default_config(self, **kwargs): - return """\ - # ldap_config: - # enabled: true - # uri: "ldap://ldap.example.com:389" - # start_tls: true - # base: "ou=users,dc=example,dc=com" - # attributes: - # uid: "cn" - # mail: "email" - # name: "givenName" - # #bind_dn: - # #bind_password: - # #filter: "(objectClass=posixAccount)" - """ diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py new file mode 100644 index 0000000000..f6d9bb1c62 --- /dev/null +++ b/synapse/config/password_auth_providers.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 Openmarket +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + +import importlib + + +class PasswordAuthProviderConfig(Config): + def read_config(self, config): + self.password_providers = [] + + # We want to be backwards compatible with the old `ldap_config` + # param. + ldap_config = config.get("ldap_config", {}) + self.ldap_enabled = ldap_config.get("enabled", False) + if self.ldap_enabled: + from synapse.util.ldap_auth_provider import LdapAuthProvider + parsed_config = LdapAuthProvider.parse_config(ldap_config) + self.password_providers.append((LdapAuthProvider, parsed_config)) + + providers = config.get("password_providers", []) + for provider in providers: + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. + module, clz = provider['module'].rsplit(".", 1) + module = importlib.import_module(module) + provider_class = getattr(module, clz) + + provider_config = provider_class.parse_config(provider["config"]) + self.password_providers.append((provider_class, provider_config)) + + def default_config(self, **kwargs): + return """\ + # password_providers: + # - module: "synapse.util.ldap_auth_provider.LdapAuthProvider" + # config: + # enabled: true + # uri: "ldap://ldap.example.com:389" + # start_tls: true + # base: "ou=users,dc=example,dc=com" + # attributes: + # uid: "cn" + # mail: "email" + # name: "givenName" + # #bind_dn: + # #bind_password: + # #filter: "(objectClass=posixAccount)" + """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 3933ce171a..9583ae1e93 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -20,7 +20,6 @@ from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor -from synapse.config.ldap import LDAPMode from twisted.web.client import PartialDownloadError @@ -29,13 +28,6 @@ import bcrypt import pymacaroons import simplejson -try: - import ldap3 - import ldap3.core.exceptions -except ImportError: - ldap3 = None - pass - import synapse.util.stringutils as stringutils @@ -61,21 +53,14 @@ class AuthHandler(BaseHandler): self.sessions = {} self.INVALID_TOKEN_HTTP_STATUS = 401 - self.ldap_enabled = hs.config.ldap_enabled - if self.ldap_enabled: - if not ldap3: - raise RuntimeError( - 'Missing ldap3 library. This is required for LDAP Authentication.' - ) - self.ldap_mode = hs.config.ldap_mode - self.ldap_uri = hs.config.ldap_uri - self.ldap_start_tls = hs.config.ldap_start_tls - self.ldap_base = hs.config.ldap_base - self.ldap_attributes = hs.config.ldap_attributes - if self.ldap_mode == LDAPMode.SEARCH: - self.ldap_bind_dn = hs.config.ldap_bind_dn - self.ldap_bind_password = hs.config.ldap_bind_password - self.ldap_filter = hs.config.ldap_filter + account_handler = _AccountHandler( + hs, check_user_exists=self.check_user_exists + ) + + self.password_providers = [ + module(config=config, account_handler=account_handler) + for module, config in hs.config.password_providers + ] self.hs = hs # FIXME better possibility to access registrationHandler later? self.device_handler = hs.get_device_handler() @@ -477,9 +462,10 @@ class AuthHandler(BaseHandler): Raises: LoginError if the password was incorrect """ - valid_ldap = yield self._check_ldap_password(user_id, password) - if valid_ldap: - defer.returnValue(user_id) + for provider in self.password_providers: + is_valid = yield provider.check_password(user_id, password) + if is_valid: + defer.returnValue(user_id) result = yield self._check_local_password(user_id, password) defer.returnValue(result) @@ -505,275 +491,6 @@ class AuthHandler(BaseHandler): raise LoginError(403, "", errcode=Codes.FORBIDDEN) defer.returnValue(user_id) - def _ldap_simple_bind(self, server, localpart, password): - """ Attempt a simple bind with the credentials - given by the user against the LDAP server. - - Returns True, LDAP3Connection - if the bind was successful - Returns False, None - if an error occured - """ - - try: - # bind with the the local users ldap credentials - bind_dn = "{prop}={value},{base}".format( - prop=self.ldap_attributes['uid'], - value=localpart, - base=self.ldap_base - ) - conn = ldap3.Connection(server, bind_dn, password) - logger.debug( - "Established LDAP connection in simple bind mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded LDAP connection in simple bind mode through StartTLS: %s", - conn - ) - - if conn.bind(): - # GOOD: bind okay - logger.debug("LDAP Bind successful in simple bind mode.") - return True, conn - - # BAD: bind failed - logger.info( - "Binding against LDAP failed for '%s' failed: %s", - localpart, conn.result['description'] - ) - conn.unbind() - return False, None - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during LDAP authentication: %s", e) - return False, None - - def _ldap_authenticated_search(self, server, localpart, password): - """ Attempt to login with the preconfigured bind_dn - and then continue searching and filtering within - the base_dn - - Returns (True, LDAP3Connection) - if a single matching DN within the base was found - that matched the filter expression, and with which - a successful bind was achieved - - The LDAP3Connection returned is the instance that was used to - verify the password not the one using the configured bind_dn. - Returns (False, None) - if an error occured - """ - - try: - conn = ldap3.Connection( - server, - self.ldap_bind_dn, - self.ldap_bind_password - ) - logger.debug( - "Established LDAP connection in search mode: %s", - conn - ) - - if self.ldap_start_tls: - conn.start_tls() - logger.debug( - "Upgraded LDAP connection in search mode through StartTLS: %s", - conn - ) - - if not conn.bind(): - logger.warn( - "Binding against LDAP with `bind_dn` failed: %s", - conn.result['description'] - ) - conn.unbind() - return False, None - - # construct search_filter like (uid=localpart) - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - if self.ldap_filter: - # combine with the AND expression - query = "(&{query}{filter})".format( - query=query, - filter=self.ldap_filter - ) - logger.debug( - "LDAP search filter: %s", - query - ) - conn.search( - search_base=self.ldap_base, - search_filter=query - ) - - if len(conn.response) == 1: - # GOOD: found exactly one result - user_dn = conn.response[0]['dn'] - logger.debug('LDAP search found dn: %s', user_dn) - - # unbind and simple bind with user_dn to verify the password - # Note: do not use rebind(), for some reason it did not verify - # the password for me! - conn.unbind() - return self._ldap_simple_bind(server, localpart, password) - else: - # BAD: found 0 or > 1 results, abort! - if len(conn.response) == 0: - logger.info( - "LDAP search returned no results for '%s'", - localpart - ) - else: - logger.info( - "LDAP search returned too many (%s) results for '%s'", - len(conn.response), localpart - ) - conn.unbind() - return False, None - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during LDAP authentication: %s", e) - return False, None - - @defer.inlineCallbacks - def _check_ldap_password(self, user_id, password): - """ Attempt to authenticate a user against an LDAP Server - and register an account if none exists. - - Returns: - True if authentication against LDAP was successful - """ - - if not ldap3 or not self.ldap_enabled: - defer.returnValue(False) - - localpart = UserID.from_string(user_id).localpart - - try: - server = ldap3.Server(self.ldap_uri) - logger.debug( - "Attempting LDAP connection with %s", - self.ldap_uri - ) - - if self.ldap_mode == LDAPMode.SIMPLE: - result, conn = self._ldap_simple_bind( - server=server, localpart=localpart, password=password - ) - logger.debug( - 'LDAP authentication method simple bind returned: %s (conn: %s)', - result, - conn - ) - if not result: - defer.returnValue(False) - elif self.ldap_mode == LDAPMode.SEARCH: - result, conn = self._ldap_authenticated_search( - server=server, localpart=localpart, password=password - ) - logger.debug( - 'LDAP auth method authenticated search returned: %s (conn: %s)', - result, - conn - ) - if not result: - defer.returnValue(False) - else: - raise RuntimeError( - 'Invalid LDAP mode specified: {mode}'.format( - mode=self.ldap_mode - ) - ) - - try: - logger.info( - "User authenticated against LDAP server: %s", - conn - ) - except NameError: - logger.warn("Authentication method yielded no LDAP connection, aborting!") - defer.returnValue(False) - - # check if user with user_id exists - if (yield self.check_user_exists(user_id)): - # exists, authentication complete - conn.unbind() - defer.returnValue(True) - - else: - # does not exist, fetch metadata for account creation from - # existing ldap connection - query = "({prop}={value})".format( - prop=self.ldap_attributes['uid'], - value=localpart - ) - - if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: - query = "(&{filter}{user_filter})".format( - filter=query, - user_filter=self.ldap_filter - ) - logger.debug( - "ldap registration filter: %s", - query - ) - - conn.search( - search_base=self.ldap_base, - search_filter=query, - attributes=[ - self.ldap_attributes['name'], - self.ldap_attributes['mail'] - ] - ) - - if len(conn.response) == 1: - attrs = conn.response[0]['attributes'] - mail = attrs[self.ldap_attributes['mail']][0] - name = attrs[self.ldap_attributes['name']][0] - - # create account - registration_handler = self.hs.get_handlers().registration_handler - user_id, access_token = ( - yield registration_handler.register(localpart=localpart) - ) - - # TODO: bind email, set displayname with data from ldap directory - - logger.info( - "Registration based on LDAP data was successful: %d: %s (%s, %)", - user_id, - localpart, - name, - mail - ) - - defer.returnValue(True) - else: - if len(conn.response) == 0: - logger.warn("LDAP registration failed, no result.") - else: - logger.warn( - "LDAP registration failed, too many results (%s)", - len(conn.response) - ) - - defer.returnValue(False) - - defer.returnValue(False) - - except ldap3.core.exceptions.LDAPException as e: - logger.warn("Error during ldap authentication: %s", e) - defer.returnValue(False) - @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) @@ -911,3 +628,30 @@ class AuthHandler(BaseHandler): stored_hash.encode('utf-8')) == stored_hash else: return False + + +class _AccountHandler(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, check_user_exists): + self.hs = hs + + self._check_user_exists = check_user_exists + + def check_user_exists(self, user_id): + """Check if user exissts. + + Returns: + Deferred(bool) + """ + return self._check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) diff --git a/synapse/util/ldap_auth_provider.py b/synapse/util/ldap_auth_provider.py new file mode 100644 index 0000000000..f852e9b037 --- /dev/null +++ b/synapse/util/ldap_auth_provider.py @@ -0,0 +1,368 @@ + +from twisted.internet import defer + +from synapse.config._base import ConfigError +from synapse.types import UserID + +import ldap3 +import ldap3.core.exceptions + +import logging + +try: + import ldap3 + import ldap3.core.exceptions +except ImportError: + ldap3 = None + pass + + +logger = logging.getLogger(__name__) + + +class LDAPMode(object): + SIMPLE = "simple", + SEARCH = "search", + + LIST = (SIMPLE, SEARCH) + + +class LdapAuthProvider(object): + __version__ = "0.1" + + def __init__(self, config, account_handler): + self.account_handler = account_handler + + if not ldap3: + raise RuntimeError( + 'Missing ldap3 library. This is required for LDAP Authentication.' + ) + + self.ldap_mode = config.mode + self.ldap_uri = config.uri + self.ldap_start_tls = config.start_tls + self.ldap_base = config.base + self.ldap_attributes = config.attributes + if self.ldap_mode == LDAPMode.SEARCH: + self.ldap_bind_dn = config.bind_dn + self.ldap_bind_password = config.bind_password + self.ldap_filter = config.filter + + @defer.inlineCallbacks + def check_password(self, user_id, password): + """ Attempt to authenticate a user against an LDAP Server + and register an account if none exists. + + Returns: + True if authentication against LDAP was successful + """ + localpart = UserID.from_string(user_id).localpart + + try: + server = ldap3.Server(self.ldap_uri) + logger.debug( + "Attempting LDAP connection with %s", + self.ldap_uri + ) + + if self.ldap_mode == LDAPMode.SIMPLE: + result, conn = self._ldap_simple_bind( + server=server, localpart=localpart, password=password + ) + logger.debug( + 'LDAP authentication method simple bind returned: %s (conn: %s)', + result, + conn + ) + if not result: + defer.returnValue(False) + elif self.ldap_mode == LDAPMode.SEARCH: + result, conn = self._ldap_authenticated_search( + server=server, localpart=localpart, password=password + ) + logger.debug( + 'LDAP auth method authenticated search returned: %s (conn: %s)', + result, + conn + ) + if not result: + defer.returnValue(False) + else: + raise RuntimeError( + 'Invalid LDAP mode specified: {mode}'.format( + mode=self.ldap_mode + ) + ) + + try: + logger.info( + "User authenticated against LDAP server: %s", + conn + ) + except NameError: + logger.warn( + "Authentication method yielded no LDAP connection, aborting!" + ) + defer.returnValue(False) + + # check if user with user_id exists + if (yield self.account_handler.check_user_exists(user_id)): + # exists, authentication complete + conn.unbind() + defer.returnValue(True) + + else: + # does not exist, fetch metadata for account creation from + # existing ldap connection + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + + if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: + query = "(&{filter}{user_filter})".format( + filter=query, + user_filter=self.ldap_filter + ) + logger.debug( + "ldap registration filter: %s", + query + ) + + conn.search( + search_base=self.ldap_base, + search_filter=query, + attributes=[ + self.ldap_attributes['name'], + self.ldap_attributes['mail'] + ] + ) + + if len(conn.response) == 1: + attrs = conn.response[0]['attributes'] + mail = attrs[self.ldap_attributes['mail']][0] + name = attrs[self.ldap_attributes['name']][0] + + # create account + user_id, access_token = ( + yield self.account_handler.register(localpart=localpart) + ) + + # TODO: bind email, set displayname with data from ldap directory + + logger.info( + "Registration based on LDAP data was successful: %d: %s (%s, %)", + user_id, + localpart, + name, + mail + ) + + defer.returnValue(True) + else: + if len(conn.response) == 0: + logger.warn("LDAP registration failed, no result.") + else: + logger.warn( + "LDAP registration failed, too many results (%s)", + len(conn.response) + ) + + defer.returnValue(False) + + defer.returnValue(False) + + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during ldap authentication: %s", e) + defer.returnValue(False) + + @staticmethod + def parse_config(config): + class _LdapConfig(object): + pass + + ldap_config = _LdapConfig() + + ldap_config.enabled = config.get("enabled", False) + + ldap_config.mode = LDAPMode.SIMPLE + + # verify config sanity + _require_keys(config, [ + "uri", + "base", + "attributes", + ]) + + ldap_config.uri = config["uri"] + ldap_config.start_tls = config.get("start_tls", False) + ldap_config.base = config["base"] + ldap_config.attributes = config["attributes"] + + if "bind_dn" in config: + ldap_config.mode = LDAPMode.SEARCH + _require_keys(config, [ + "bind_dn", + "bind_password", + ]) + + ldap_config.bind_dn = config["bind_dn"] + ldap_config.bind_password = config["bind_password"] + ldap_config.filter = config.get("filter", None) + + # verify attribute lookup + _require_keys(config['attributes'], [ + "uid", + "name", + "mail", + ]) + + return ldap_config + + def _ldap_simple_bind(self, server, localpart, password): + """ Attempt a simple bind with the credentials + given by the user against the LDAP server. + + Returns True, LDAP3Connection + if the bind was successful + Returns False, None + if an error occured + """ + + try: + # bind with the the local users ldap credentials + bind_dn = "{prop}={value},{base}".format( + prop=self.ldap_attributes['uid'], + value=localpart, + base=self.ldap_base + ) + conn = ldap3.Connection(server, bind_dn, password) + logger.debug( + "Established LDAP connection in simple bind mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded LDAP connection in simple bind mode through StartTLS: %s", + conn + ) + + if conn.bind(): + # GOOD: bind okay + logger.debug("LDAP Bind successful in simple bind mode.") + return True, conn + + # BAD: bind failed + logger.info( + "Binding against LDAP failed for '%s' failed: %s", + localpart, conn.result['description'] + ) + conn.unbind() + return False, None + + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during LDAP authentication: %s", e) + return False, None + + def _ldap_authenticated_search(self, server, localpart, password): + """ Attempt to login with the preconfigured bind_dn + and then continue searching and filtering within + the base_dn + + Returns (True, LDAP3Connection) + if a single matching DN within the base was found + that matched the filter expression, and with which + a successful bind was achieved + + The LDAP3Connection returned is the instance that was used to + verify the password not the one using the configured bind_dn. + Returns (False, None) + if an error occured + """ + + try: + conn = ldap3.Connection( + server, + self.ldap_bind_dn, + self.ldap_bind_password + ) + logger.debug( + "Established LDAP connection in search mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded LDAP connection in search mode through StartTLS: %s", + conn + ) + + if not conn.bind(): + logger.warn( + "Binding against LDAP with `bind_dn` failed: %s", + conn.result['description'] + ) + conn.unbind() + return False, None + + # construct search_filter like (uid=localpart) + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + if self.ldap_filter: + # combine with the AND expression + query = "(&{query}{filter})".format( + query=query, + filter=self.ldap_filter + ) + logger.debug( + "LDAP search filter: %s", + query + ) + conn.search( + search_base=self.ldap_base, + search_filter=query + ) + + if len(conn.response) == 1: + # GOOD: found exactly one result + user_dn = conn.response[0]['dn'] + logger.debug('LDAP search found dn: %s', user_dn) + + # unbind and simple bind with user_dn to verify the password + # Note: do not use rebind(), for some reason it did not verify + # the password for me! + conn.unbind() + return self._ldap_simple_bind(server, localpart, password) + else: + # BAD: found 0 or > 1 results, abort! + if len(conn.response) == 0: + logger.info( + "LDAP search returned no results for '%s'", + localpart + ) + else: + logger.info( + "LDAP search returned too many (%s) results for '%s'", + len(conn.response), localpart + ) + conn.unbind() + return False, None + + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during LDAP authentication: %s", e) + return False, None + + +def _require_keys(config, required): + missing = [key for key in required if key not in config] + if missing: + raise ConfigError( + "LDAP enabled but missing required config values: {}".format( + ", ".join(missing) + ) + ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 3e2862daae..dfe2338554 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -37,6 +37,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): config = Mock( app_service_config_files=self.as_yaml_files, event_cache_size=1, + password_providers=[], ) hs = yield setup_test_homeserver(config=config) @@ -112,6 +113,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): config = Mock( app_service_config_files=self.as_yaml_files, event_cache_size=1, + password_providers=[], ) hs = yield setup_test_homeserver(config=config) self.db_pool = hs.get_db_pool() @@ -440,7 +442,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) + config = Mock( + app_service_config_files=[f1, f2], event_cache_size=1, + password_providers=[] + ) hs = yield setup_test_homeserver(config=config, datastore=Mock()) ApplicationServiceStore(hs) @@ -450,7 +455,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) + config = Mock( + app_service_config_files=[f1, f2], event_cache_size=1, + password_providers=[] + ) hs = yield setup_test_homeserver(config=config, datastore=Mock()) with self.assertRaises(ConfigError) as cm: @@ -466,7 +474,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - config = Mock(app_service_config_files=[f1, f2], event_cache_size=1) + config = Mock( + app_service_config_files=[f1, f2], event_cache_size=1, + password_providers=[] + ) hs = yield setup_test_homeserver(config=config, datastore=Mock()) with self.assertRaises(ConfigError) as cm: diff --git a/tests/utils.py b/tests/utils.py index 92d470cb48..f74526b6a7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -52,6 +52,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.server_name = name config.trusted_third_party_id_servers = [] config.room_invite_state_types = [] + config.password_providers = [] config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} -- cgit 1.4.1 From 9bfc61779111624e972939491a0b5c02190d3463 Mon Sep 17 00:00:00 2001 From: Patrik Oldsberg Date: Thu, 6 Oct 2016 10:43:32 +0200 Subject: storage/appservice: make appservice methods only relying on the cache synchronous --- synapse/api/auth.py | 7 +++---- synapse/handlers/appservice.py | 20 +++++++++----------- synapse/handlers/directory.py | 11 ++++------- synapse/handlers/register.py | 5 ++--- synapse/handlers/room.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/rest/client/v1/register.py | 2 +- synapse/storage/appservice.py | 12 ++++++------ tests/rest/client/v2_alpha/test_register.py | 2 +- tests/storage/test_appservice.py | 9 +++------ 10 files changed, 31 insertions(+), 41 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e75fd518be..27599124d2 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -653,7 +653,7 @@ class Auth(object): @defer.inlineCallbacks def _get_appservice_user_id(self, request): - app_service = yield self.store.get_app_service_by_token( + app_service = self.store.get_app_service_by_token( get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) @@ -855,13 +855,12 @@ class Auth(object): } defer.returnValue(user_info) - @defer.inlineCallbacks def get_appservice_by_req(self, request): try: token = get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS ) - service = yield self.store.get_app_service_by_token(token) + service = self.store.get_app_service_by_token(token) if not service: logger.warn("Unrecognised appservice access token: %s" % (token,)) raise AuthError( @@ -870,7 +869,7 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) request.authenticated_entity = service.sender - defer.returnValue(service) + return defer.succeed(service) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 88fa0bb2e4..05af54d31b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -59,7 +59,7 @@ class ApplicationServicesHandler(object): Args: current_id(int): The current maximum ID. """ - services = yield self.store.get_app_services() + services = self.store.get_app_services() if not services or not self.notify_appservices: return @@ -142,7 +142,7 @@ class ApplicationServicesHandler(object): association can be found. """ room_alias_str = room_alias.to_string() - services = yield self.store.get_app_services() + services = self.store.get_app_services() alias_query_services = [ s for s in services if ( s.is_interested_in_alias(room_alias_str) @@ -177,7 +177,7 @@ class ApplicationServicesHandler(object): @defer.inlineCallbacks def get_3pe_protocols(self, only_protocol=None): - services = yield self.store.get_app_services() + services = self.store.get_app_services() protocols = {} # Collect up all the individual protocol responses out of the ASes @@ -224,7 +224,7 @@ class ApplicationServicesHandler(object): list: A list of services interested in this event based on the service regex. """ - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( yield s.is_interested(event, self.store) @@ -232,23 +232,21 @@ class ApplicationServicesHandler(object): ] defer.returnValue(interested_list) - @defer.inlineCallbacks def _get_services_for_user(self, user_id): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( s.is_interested_in_user(user_id) ) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) - @defer.inlineCallbacks def _get_services_for_3pn(self, protocol): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if s.is_interested_in_protocol(protocol) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) @defer.inlineCallbacks def _is_unknown_user(self, user_id): @@ -264,7 +262,7 @@ class ApplicationServicesHandler(object): return # user not found; could be the AS though, so check. - services = yield self.store.get_app_services() + services = self.store.get_app_services() service_list = [s for s in services if s.sender == user_id] defer.returnValue(len(service_list) == 0) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 14352985e2..c00274afc3 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -288,13 +288,12 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) - @defer.inlineCallbacks def can_modify_alias(self, alias, user_id=None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have # non-exclusive locks on the alias (or there are no interested services) - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_alias(alias.to_string()) ] @@ -302,14 +301,12 @@ class DirectoryHandler(BaseHandler): for service in interested_services: if user_id == service.sender: # this user IS the app service so they can do whatever they like - defer.returnValue(True) - return + return defer.succeed(True) elif service.is_exclusive_alias(alias.to_string()): # another service has an exclusive lock on this alias. - defer.returnValue(False) - return + return defer.succeed(False) # either no interested services, or no service with an exclusive lock - defer.returnValue(True) + return defer.succeed(True) @defer.inlineCallbacks def _user_can_delete_alias(self, alias, user_id): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index dd75c4fecf..19329057d5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -194,7 +194,7 @@ class RegistrationHandler(BaseHandler): def appservice_register(self, user_localpart, as_token): user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() - service = yield self.store.get_app_service_by_token(as_token) + service = self.store.get_app_service_by_token(as_token) if not service: raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): @@ -305,11 +305,10 @@ class RegistrationHandler(BaseHandler): # XXX: This should be a deferred list, shouldn't it? yield identity_handler.bind_threepid(c, user_id) - @defer.inlineCallbacks def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): # valid user IDs must not clash with any user ID namespaces claimed by # application services. - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_user(user_id) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index cbd26f8f95..a7f533f7be 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -437,7 +437,7 @@ class RoomEventSource(object): logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = yield self.store.get_app_service_by_user_id( + app_service = self.store.get_app_service_by_user_id( user.to_string() ) if app_service: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b5962f4f5a..1f910ff814 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -788,7 +788,7 @@ class SyncHandler(object): assert since_token - app_service = yield self.store.get_app_service_by_user_id(user_id) + app_service = self.store.get_app_service_by_user_id(user_id) if app_service: rooms = yield self.store.get_app_service_rooms(app_service) joined_room_ids = set(r.room_id for r in rooms) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 3046da7aec..fe4480b363 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -391,7 +391,7 @@ class CreateUserRestServlet(ClientV1RestServlet): user_json = parse_json_object_from_request(request) access_token = get_access_token_from_request(request) - app_service = yield self.store.get_app_service_by_token( + app_service = self.store.get_app_service_by_token( access_token ) if not app_service: diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index a854a87eab..3d5994a580 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -37,7 +37,7 @@ class ApplicationServiceStore(SQLBaseStore): ) def get_app_services(self): - return defer.succeed(self.services_cache) + return self.services_cache def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. @@ -54,8 +54,8 @@ class ApplicationServiceStore(SQLBaseStore): """ for service in self.services_cache: if service.sender == user_id: - return defer.succeed(service) - return defer.succeed(None) + return service + return None def get_app_service_by_token(self, token): """Get the application service with the given appservice token. @@ -67,8 +67,8 @@ class ApplicationServiceStore(SQLBaseStore): """ for service in self.services_cache: if service.token == token: - return defer.succeed(service) - return defer.succeed(None) + return service + return None def get_app_service_rooms(self, service): """Get a list of RoomsForUser for this application service. @@ -163,7 +163,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore): ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore - as_list = yield self.get_app_services() + as_list = self.get_app_services() services = [] for res in results: diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8ac56a1fb2..e9cb416e4b 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -19,7 +19,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.appservice = None self.auth = Mock(get_appservice_by_req=Mock( - side_effect=lambda x: defer.succeed(self.appservice)) + side_effect=lambda x: self.appservice) ) self.auth_result = (False, None, None, None) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 3e2862daae..f3df8302da 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -71,14 +71,12 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - @defer.inlineCallbacks def test_retrieve_unknown_service_token(self): - service = yield self.store.get_app_service_by_token("invalid_token") + service = self.store.get_app_service_by_token("invalid_token") self.assertEquals(service, None) - @defer.inlineCallbacks def test_retrieval_of_service(self): - stored_service = yield self.store.get_app_service_by_token( + stored_service = self.store.get_app_service_by_token( self.as_token ) self.assertEquals(stored_service.token, self.as_token) @@ -97,9 +95,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): [] ) - @defer.inlineCallbacks def test_retrieval_of_all_services(self): - services = yield self.store.get_app_services() + services = self.store.get_app_services() self.assertEquals(len(services), 3) -- cgit 1.4.1 From 7b5546d0776d8d238f4c8775ea8878d92ecc2527 Mon Sep 17 00:00:00 2001 From: Patrik Oldsberg Date: Tue, 4 Oct 2016 22:32:58 +0200 Subject: rest/client/v1/register: use the correct requester in createUser Signed-off-by: Patrik Oldsberg --- synapse/handlers/register.py | 6 ++---- synapse/rest/client/v1/register.py | 9 ++++++--- tests/handlers/test_register.py | 8 +++++--- tests/rest/client/v1/test_register.py | 30 +++++++++--------------------- 4 files changed, 22 insertions(+), 31 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 19329057d5..7e119f13b1 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -19,7 +19,6 @@ import urllib from twisted.internet import defer -import synapse.types from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) @@ -370,7 +369,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_in_ms, + def get_or_create_user(self, requester, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -417,9 +416,8 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler - requester = synapse.types.create_requester(user) yield profile_handler.set_displayname( - user, requester, displayname + user, requester, displayname, by_admin=True, ) defer.returnValue((user_id, token)) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index fe4480b363..b5a76fefac 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.http.servlet import parse_json_object_from_request +from synapse.types import create_requester from synapse.util.async import run_on_reactor @@ -397,9 +398,10 @@ class CreateUserRestServlet(ClientV1RestServlet): if not app_service: raise SynapseError(403, "Invalid application service token.") - logger.debug("creating user: %s", user_json) + requester = create_requester(app_service.sender) - response = yield self._do_create(user_json) + logger.debug("creating user: %s", user_json) + response = yield self._do_create(requester, user_json) defer.returnValue((200, response)) @@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet): return 403, {} @defer.inlineCallbacks - def _do_create(self, user_json): + def _do_create(self, requester, user_json): yield run_on_reactor() if "localpart" not in user_json: @@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet): handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( + requester=requester, localpart=localpart, displayname=displayname, duration_in_ms=(duration_seconds * 1000), diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a7de3c7c17..9c9d144690 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,7 +17,7 @@ from twisted.internet import defer from .. import unittest from synapse.handlers.register import RegistrationHandler -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests.utils import setup_test_homeserver @@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase): local_part = "someone" display_name = "someone" user_id = "@someone:test" + requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - local_part, display_name, duration_ms) + requester, local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase): local_part = "frank" display_name = "Frank" user_id = "@frank:test" + requester = create_requester("@as:test") result_user_id, result_token = yield self.handler.get_or_create_user( - local_part, display_name, duration_ms) + requester, local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index 4a898a034f..44ba9ff58f 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase): ) self.request.args = {} - self.appservice = None - self.auth = Mock(get_appservice_by_req=Mock( - side_effect=lambda x: defer.succeed(self.appservice)) - ) + self.registration_handler = Mock() - self.auth_result = (False, None, None, None) - self.auth_handler = Mock( - check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), - get_session_data=Mock(return_value=None) + self.appservice = Mock(sender="@as:test") + self.datastore = Mock( + get_app_service_by_token=Mock(return_value=self.appservice) ) - self.registration_handler = Mock() - self.identity_handler = Mock() - self.login_handler = Mock() - # do the dance to hook it up to the hs global - self.handlers = Mock( - auth_handler=self.auth_handler, + # do the dance to hook things up to the hs global + handlers = Mock( registration_handler=self.registration_handler, - identity_handler=self.identity_handler, - login_handler=self.login_handler ) self.hs = Mock() - self.hs.hostname = "supergbig~testing~thing.com" - self.hs.get_auth = Mock(return_value=self.auth) - self.hs.get_handlers = Mock(return_value=self.handlers) - self.hs.config.enable_registration = True - # init the thing we're testing + self.hs.hostname = "superbig~testing~thing.com" + self.hs.get_datastore = Mock(return_value=self.datastore) + self.hs.get_handlers = Mock(return_value=handlers) self.servlet = CreateUserRestServlet(self.hs) @defer.inlineCallbacks -- cgit 1.4.1 From d9350b0db846dfe996971797052763428739f3ad Mon Sep 17 00:00:00 2001 From: Alexander Maznev Date: Mon, 10 Oct 2016 07:25:26 -0500 Subject: Error codes for filters * add tests Signed-off-by: Alexander Maznev --- synapse/rest/client/v2_alpha/filter.py | 8 ++++---- tests/rest/client/v2_alpha/test_filter.py | 15 ++++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 510f8b2c74..f7758fc68c 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, SynapseError, StoreError, Codes from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -45,7 +45,7 @@ class GetFilterRestServlet(RestServlet): raise AuthError(403, "Cannot get filters for other users") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only get filters for local users") + raise AuthError(403, "Can only get filters for local users") try: filter_id = int(filter_id) @@ -59,8 +59,8 @@ class GetFilterRestServlet(RestServlet): ) defer.returnValue((200, filter.get_filter_json())) - except KeyError: - raise SynapseError(400, "No such filter") + except (KeyError, StoreError): + raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND) class CreateFilterRestServlet(RestServlet): diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index d1442aafac..47ca5e8c8a 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -19,7 +19,7 @@ from . import V2AlphaRestTestCase from synapse.rest.client.v2_alpha import filter -from synapse.api.errors import StoreError +from synapse.api.errors import StoreError, Codes class FilterTestCase(V2AlphaRestTestCase): @@ -82,11 +82,20 @@ class FilterTestCase(V2AlphaRestTestCase): (code, response) = yield self.mock_resource.trigger_get( "/user/%s/filter/2" % (self.USER_ID) ) - self.assertEquals(404, code) + self.assertEquals(400, code) @defer.inlineCallbacks def test_get_filter_no_user(self): (code, response) = yield self.mock_resource.trigger_get( "/user/%s/filter/0" % (self.USER_ID) ) - self.assertEquals(404, code) + self.assertEquals(400, code) + self.assertEquals(response['errcode'], Codes.FORBIDDEN) + + @defer.inlineCallbacks + def test_get_filter_missing_id(self): + (code, response) = yield self.mock_resource.trigger_get( + "/user/%s/filter/0" % (self.USER_ID) + ) + self.assertEquals(400, code) + self.assertEquals(response['errcode'], Codes.NOT_FOUND) -- cgit 1.4.1 From d43b63818c174cfaa41fab13334b159c155dd2da Mon Sep 17 00:00:00 2001 From: pik Date: Fri, 14 Oct 2016 15:46:54 -0500 Subject: Fix MockHttpRequest always returning M_UNKNOWN errcode in testing --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index 92d470cb48..5893145a16 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -188,7 +188,7 @@ class MockHttpResource(HttpServer): ) defer.returnValue((code, response)) except CodeMessageException as e: - defer.returnValue((e.code, cs_error(e.msg))) + defer.returnValue((e.code, cs_error(e.msg, code=e.errcode))) raise KeyError("No event can handle %s" % path) -- cgit 1.4.1 From e8b1d2a45200d88a576360a83e3dff1aac4ad679 Mon Sep 17 00:00:00 2001 From: pik Date: Tue, 18 Oct 2016 12:17:38 -0500 Subject: Refactor test_filter to use real DataStore * add tests for filter api errors --- synapse/rest/client/v2_alpha/filter.py | 4 +- synapse/storage/_base.py | 1 - tests/rest/client/v2_alpha/test_filter.py | 124 +++++++++++++++++++----------- 3 files changed, 83 insertions(+), 46 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index f7758fc68c..b4084fec62 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): + target_user = UserID.from_string(user_id) requester = yield self.auth.get_user_by_req(request) @@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Cannot create filters for other users") if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only create filters for local users") + raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - filter_id = yield self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 49fa8614f2..d828d6ee1d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -85,7 +85,6 @@ class LoggingTransaction(object): sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql = self.database_engine.convert_param_style(sql) - if args: try: sql_logger.debug( diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 47ca5e8c8a..3d27d03cbf 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,87 +15,125 @@ from twisted.internet import defer -from . import V2AlphaRestTestCase +from tests import unittest from synapse.rest.client.v2_alpha import filter -from synapse.api.errors import StoreError, Codes +from synapse.api.errors import Codes +import synapse.types + +from synapse.types import UserID + +from ....utils import MockHttpResource, setup_test_homeserver + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class FilterTestCase(unittest.TestCase): -class FilterTestCase(V2AlphaRestTestCase): USER_ID = "@apple:test" + EXAMPLE_FILTER = {"type": ["m.*"]} + EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}' TO_REGISTER = [filter] - def make_datastore_mock(self): - datastore = super(FilterTestCase, self).make_datastore_mock() + @defer.inlineCallbacks + def setUp(self): + self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) + + self.hs = yield setup_test_homeserver( + http_client=None, + resource_for_client=self.mock_resource, + resource_for_federation=self.mock_resource, + ) + + self.auth = self.hs.get_auth() - self._user_filters = {} + def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.USER_ID), + "token_id": 1, + "is_guest": False, + } - def add_user_filter(user_localpart, definition): - filters = self._user_filters.setdefault(user_localpart, []) - filter_id = len(filters) - filters.append(definition) - return defer.succeed(filter_id) - datastore.add_user_filter = add_user_filter + def get_user_by_req(request, allow_guest=False, rights="access"): + return synapse.types.create_requester( + UserID.from_string(self.USER_ID), 1, False, None) - def get_user_filter(user_localpart, filter_id): - if user_localpart not in self._user_filters: - raise StoreError(404, "No user") - filters = self._user_filters[user_localpart] - if filter_id >= len(filters): - raise StoreError(404, "No filter") - return defer.succeed(filters[filter_id]) - datastore.get_user_filter = get_user_filter + self.auth.get_user_by_access_token = get_user_by_access_token + self.auth.get_user_by_req = get_user_by_req - return datastore + self.store = self.hs.get_datastore() + self.filtering = self.hs.get_filtering() + + for r in self.TO_REGISTER: + r.register_servlets(self.hs, self.mock_resource) @defer.inlineCallbacks def test_add_filter(self): (code, response) = yield self.mock_resource.trigger( - "POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}' + "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON ) self.assertEquals(200, code) self.assertEquals({"filter_id": "0"}, response) + filter = yield self.store.get_user_filter( + user_localpart='apple', + filter_id=0, + ) + self.assertEquals(filter, self.EXAMPLE_FILTER) + + @defer.inlineCallbacks + def test_add_filter_for_other_user(self): + (code, response) = yield self.mock_resource.trigger( + "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON + ) + self.assertEquals(403, code) + self.assertEquals(response['errcode'], Codes.FORBIDDEN) - self.assertIn("apple", self._user_filters) - self.assertEquals(len(self._user_filters["apple"]), 1) - self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0]) + @defer.inlineCallbacks + def test_add_filter_non_local_user(self): + _is_mine = self.hs.is_mine + self.hs.is_mine = lambda target_user: False + (code, response) = yield self.mock_resource.trigger( + "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON + ) + self.hs.is_mine = _is_mine + self.assertEquals(403, code) + self.assertEquals(response['errcode'], Codes.FORBIDDEN) @defer.inlineCallbacks def test_get_filter(self): - self._user_filters["apple"] = [ - {"type": ["m.*"]} - ] - + filter_id = yield self.filtering.add_user_filter( + user_localpart='apple', + user_filter=self.EXAMPLE_FILTER + ) (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/0" % (self.USER_ID) + "/user/%s/filter/%s" % (self.USER_ID, filter_id) ) self.assertEquals(200, code) - self.assertEquals({"type": ["m.*"]}, response) + self.assertEquals(self.EXAMPLE_FILTER, response) @defer.inlineCallbacks - def test_get_filter_no_id(self): - self._user_filters["apple"] = [ - {"type": ["m.*"]} - ] - + def test_get_filter_non_existant(self): (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/2" % (self.USER_ID) + "/user/%s/filter/12382148321" % (self.USER_ID) ) self.assertEquals(400, code) + self.assertEquals(response['errcode'], Codes.NOT_FOUND) + # Currently invalid params do not have an appropriate errcode + # in errors.py @defer.inlineCallbacks - def test_get_filter_no_user(self): + def test_get_filter_invalid_id(self): (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/0" % (self.USER_ID) + "/user/%s/filter/foobar" % (self.USER_ID) ) self.assertEquals(400, code) - self.assertEquals(response['errcode'], Codes.FORBIDDEN) + # No ID also returns an invalid_id error @defer.inlineCallbacks - def test_get_filter_missing_id(self): + def test_get_filter_no_id(self): (code, response) = yield self.mock_resource.trigger_get( - "/user/%s/filter/0" % (self.USER_ID) + "/user/%s/filter/" % (self.USER_ID) ) self.assertEquals(400, code) - self.assertEquals(response['errcode'], Codes.NOT_FOUND) -- cgit 1.4.1