From d79e90f078c83314de3dc469770750dd2585e255 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 22 Dec 2015 17:56:56 +0000 Subject: Add mocks to make tests work again --- tests/handlers/test_federation.py | 7 +++++++ tests/handlers/test_room.py | 9 +++++++++ 2 files changed, 16 insertions(+) (limited to 'tests') diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index d392c23015..a4758c03db 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -49,6 +49,10 @@ class FederationTestCase(unittest.TestCase): "get_destination_retry_timings", "set_destination_retry_timings", "have_events", + "get_users_in_room", + "bulk_get_push_rules", + "get_current_state", + "set_actions_for_event_and_users", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), @@ -85,6 +89,9 @@ class FederationTestCase(unittest.TestCase): self.datastore.persist_event.return_value = defer.succeed((1,1)) self.datastore.get_room.return_value = defer.succeed(True) + self.datastore.get_users_in_room.return_value = ["@a:b"] + self.datastore.bulk_get_push_rules.return_value = {} + self.datastore.get_current_state.return_value = {} self.auth.check_host_in_room.return_value = defer.succeed(True) retry_timings_res = { diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index 2a7553f982..ba20b31945 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -43,6 +43,10 @@ class RoomMemberHandlerTestCase(unittest.TestCase): "store_room", "get_latest_events_in_room", "add_event_hashes", + "get_users_in_room", + "bulk_get_push_rules", + "get_current_state", + "set_actions_for_event_and_users", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), @@ -90,6 +94,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.datastore.persist_event.return_value = (1,1) self.datastore.add_event_hashes.return_value = [] + self.datastore.get_users_in_room.return_value = ["@bob:red"] + self.datastore.bulk_get_push_rules.return_value = {} @defer.inlineCallbacks def test_invite(self): @@ -109,6 +115,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.datastore.get_latest_events_in_room.return_value = ( defer.succeed([]) ) + self.datastore.get_current_state.return_value = {} def annotate(_): ctx = Mock() @@ -190,6 +197,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.datastore.get_latest_events_in_room.return_value = ( defer.succeed([]) ) + self.datastore.get_current_state.return_value = {} def annotate(_): ctx = Mock() @@ -265,6 +273,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.datastore.get_latest_events_in_room.return_value = ( defer.succeed([]) ) + self.datastore.get_current_state.return_value = {} def annotate(_): ctx = Mock() -- cgit 1.5.1 From 92a1e74b202757b0f4b577ccbd3e31d8dd4d6460 Mon Sep 17 00:00:00 2001 From: David Baker Date: Mon, 4 Jan 2016 14:17:35 +0000 Subject: fix tests --- tests/handlers/test_federation.py | 2 +- tests/handlers/test_room.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index a4758c03db..6acc4ebadc 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -52,7 +52,7 @@ class FederationTestCase(unittest.TestCase): "get_users_in_room", "bulk_get_push_rules", "get_current_state", - "set_actions_for_event_and_users", + "set_push_actions_for_event_and_users", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index ba20b31945..ff2b597124 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -46,7 +46,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): "get_users_in_room", "bulk_get_push_rules", "get_current_state", - "set_actions_for_event_and_users", + "set_push_actions_for_event_and_users", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), -- cgit 1.5.1 From cfd07aafff71b452a01265f304172f56b2c49759 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 5 Jan 2016 18:01:18 +0000 Subject: Allow guests to upgrade their accounts --- synapse/api/auth.py | 6 ++-- synapse/handlers/auth.py | 6 ++-- synapse/handlers/register.py | 37 +++++++++++++++++------ synapse/handlers/room.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/rest/client/v2_alpha/register.py | 12 ++++++-- synapse/rest/media/v1/thumbnail_resource.py | 2 +- synapse/storage/prepare_database.py | 4 +-- synapse/storage/registration.py | 23 +++++++++----- synapse/storage/schema/delta/28/upgrade_times.sql | 21 +++++++++++++ tests/api/test_auth.py | 18 +++++------ 11 files changed, 93 insertions(+), 40 deletions(-) create mode 100644 synapse/storage/schema/delta/28/upgrade_times.sql (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index adb7d64482..b86c6c8399 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2014 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -583,7 +583,7 @@ class Auth(object): AuthError if no user by that token exists or the token is invalid. """ try: - ret = yield self._get_user_from_macaroon(token) + ret = yield self.get_user_from_macaroon(token) except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. @@ -591,7 +591,7 @@ class Auth(object): defer.returnValue(ret) @defer.inlineCallbacks - def _get_user_from_macaroon(self, macaroon_str): + def get_user_from_macaroon(self, macaroon_str): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) self.validate_macaroon(macaroon, "access", False) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e64b67cdfd..62e82a2570 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2014 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -408,7 +408,7 @@ class AuthHandler(BaseHandler): macaroon = pymacaroons.Macaroon.deserialize(login_token) auth_api = self.hs.get_auth() auth_api.validate_macaroon(macaroon, "login", True) - return self._get_user_from_macaroon(macaroon) + return self.get_user_from_macaroon(macaroon) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) @@ -421,7 +421,7 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def _get_user_from_macaroon(self, macaroon): + def get_user_from_macaroon(self, macaroon): user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index baf7c14e40..6f111ff63e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2014 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,12 +40,13 @@ class RegistrationHandler(BaseHandler): def __init__(self, hs): super(RegistrationHandler, self).__init__(hs) + self.auth = hs.get_auth() self.distributor = hs.get_distributor() self.distributor.declare("registered_user") self.captcha_client = CaptchaServerHttpClient(hs) @defer.inlineCallbacks - def check_username(self, localpart): + def check_username(self, localpart, guest_access_token=None): yield run_on_reactor() if urllib.quote(localpart) != localpart: @@ -62,14 +63,29 @@ class RegistrationHandler(BaseHandler): users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: - raise SynapseError( - 400, - "User ID already taken.", - errcode=Codes.USER_IN_USE, - ) + if not guest_access_token: + raise SynapseError( + 400, + "User ID already taken.", + errcode=Codes.USER_IN_USE, + ) + user_data = yield self.auth.get_user_from_macaroon(guest_access_token) + if not user_data["is_guest"] or user_data["user"].localpart != localpart: + raise AuthError( + 403, + "Cannot register taken user ID without valid guest " + "credentials for that user.", + errcode=Codes.FORBIDDEN, + ) @defer.inlineCallbacks - def register(self, localpart=None, password=None, generate_token=True): + def register( + self, + localpart=None, + password=None, + generate_token=True, + guest_access_token=None + ): """Registers a new client on the server. Args: @@ -89,7 +105,7 @@ class RegistrationHandler(BaseHandler): password_hash = self.auth_handler().hash(password) if localpart: - yield self.check_username(localpart) + yield self.check_username(localpart, guest_access_token=guest_access_token) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -100,7 +116,8 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=password_hash + password_hash=password_hash, + was_guest=guest_access_token is not None, ) yield registered_user(self.distributor, user) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0cfeda10d8..6186c37c7c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2014 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9796f2a57f..41a42418a9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd +# Copyright 2015 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b2b89652c6..25389ceded 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd +# Copyright 2015 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -119,8 +119,13 @@ class RegisterRestServlet(RestServlet): if self.hs.config.disable_registration: raise SynapseError(403, "Registration has been disabled") + guest_access_token = body.get("guest_access_token", None) + if desired_username is not None: - yield self.registration_handler.check_username(desired_username) + yield self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token + ) if self.hs.config.enable_registration_captcha: flows = [ @@ -150,7 +155,8 @@ class RegisterRestServlet(RestServlet): (user_id, token) = yield self.registration_handler.register( localpart=desired_username, - password=new_password + password=new_password, + guest_access_token=guest_access_token, ) if result and LoginType.EMAIL_IDENTITY in result: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 8b8fba3dc7..c18160534e 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2014 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 16eff62544..c1f5f99789 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2014 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 27 +SCHEMA_VERSION = 28 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 09a05b08ef..f0fa0bd33c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2014 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,30 +73,39 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks - def register(self, user_id, token, password_hash): + def register(self, user_id, token, password_hash, was_guest=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. token (str): The desired access token to use for this user. password_hash (str): Optional. The password hash for this user. + was_guest (bool): Optional. Whether this is a guest account being + upgraded to a non-guest account. Raises: StoreError if the user_id could not be registered. """ yield self.runInteraction( "register", - self._register, user_id, token, password_hash + self._register, user_id, token, password_hash, was_guest ) - def _register(self, txn, user_id, token, password_hash): + def _register(self, txn, user_id, token, password_hash, was_guest): now = int(self.clock.time()) next_id = self._access_tokens_id_gen.get_next_txn(txn) try: - txn.execute("INSERT INTO users(name, password_hash, creation_ts) " - "VALUES (?,?,?)", - [user_id, password_hash, now]) + if was_guest: + txn.execute("UPDATE users SET" + " password_hash = ?," + " upgrade_ts = ?" + " WHERE name = ?", + [password_hash, now, user_id]) + else: + txn.execute("INSERT INTO users(name, password_hash, creation_ts) " + "VALUES (?,?,?)", + [user_id, password_hash, now]) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/schema/delta/28/upgrade_times.sql new file mode 100644 index 0000000000..3e4a9ab455 --- /dev/null +++ b/synapse/storage/schema/delta/28/upgrade_times.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Stores the timestamp when a user upgraded from a guest to a full user, if + * that happened. + */ + +ALTER TABLE users ADD COLUMN upgrade_ts BIGINT; diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 70d928defe..5ff4c8a873 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd +# Copyright 2015 - 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize()) + user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -171,7 +171,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth._get_user_from_macaroon(serialized) + user_info = yield self.auth.get_user_from_macaroon(serialized) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) @@ -192,7 +192,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) with self.assertRaises(AuthError) as cm: - yield self.auth._get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("User mismatch", cm.exception.msg) @@ -212,7 +212,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("type = access") with self.assertRaises(AuthError) as cm: - yield self.auth._get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("No user caveat", cm.exception.msg) @@ -234,7 +234,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = %s" % (user,)) with self.assertRaises(AuthError) as cm: - yield self.auth._get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -257,7 +257,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("cunning > fox") with self.assertRaises(AuthError) as cm: - yield self.auth._get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) @@ -285,11 +285,11 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 5000 # seconds - yield self.auth._get_user_from_macaroon(macaroon.serialize()) + yield self.auth.get_user_from_macaroon(macaroon.serialize()) # TODO(daniel): Turn on the check that we validate expiration, when we # validate expiration (and remove the above line, which will start # throwing). # with self.assertRaises(AuthError) as cm: - # yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # yield self.auth.get_user_from_macaroon(macaroon.serialize()) # self.assertEqual(401, cm.exception.code) # self.assertIn("Invalid macaroon", cm.exception.msg) -- cgit 1.5.1 From 0e48f7f2458f08341131b3b90c78b7034fe02d14 Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 6 Jan 2016 16:46:41 +0000 Subject: fix tests --- tests/handlers/test_federation.py | 4 ++++ tests/handlers/test_room.py | 5 +++++ tests/storage/test_registration.py | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 6acc4ebadc..029c094115 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -53,6 +53,8 @@ class FederationTestCase(unittest.TestCase): "bulk_get_push_rules", "get_current_state", "set_push_actions_for_event_and_users", + "is_guest", + "get_state_for_events", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), @@ -73,6 +75,8 @@ class FederationTestCase(unittest.TestCase): self.handlers.federation_handler = FederationHandler(self.hs) + self.datastore.get_state_for_events.return_value = {"$a:b": {}} + @defer.inlineCallbacks def test_msg(self): pdu = FrozenEvent({ diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index ff2b597124..b1c8e61522 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -47,6 +47,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): "bulk_get_push_rules", "get_current_state", "set_push_actions_for_event_and_users", + "get_state_for_events", + "is_guest", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), @@ -116,6 +118,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): defer.succeed([]) ) self.datastore.get_current_state.return_value = {} + self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} def annotate(_): ctx = Mock() @@ -198,6 +201,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): defer.succeed([]) ) self.datastore.get_current_state.return_value = {} + self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} def annotate(_): ctx = Mock() @@ -274,6 +278,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): defer.succeed([]) ) self.datastore.get_current_state.return_value = {} + self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} def annotate(_): ctx = Mock() diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 0cce6c37df..4760131f9c 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -45,7 +45,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertEquals( # TODO(paul): Surely this field should be 'user_id', not 'name' # Additionally surely it shouldn't come in a 1-element list - {"name": self.user_id, "password_hash": self.pwhash}, + {"name": self.user_id, "password_hash": self.pwhash, "is_guest": 0}, (yield self.store.get_user_by_id(self.user_id)) ) -- cgit 1.5.1 From 6c28ac260c2ce4bf93737e53ea3297bff08924c7 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 7 Jan 2016 04:26:29 +0000 Subject: copyrights --- contrib/cmdclient/console.py | 2 +- contrib/cmdclient/http.py | 2 +- contrib/experiments/cursesio.py | 2 +- contrib/experiments/test_messaging.py | 2 +- contrib/graph/graph.py | 2 +- contrib/graph/graph2.py | 2 +- scripts-dev/copyrighter-sql.pl | 4 ++-- scripts-dev/copyrighter.pl | 4 ++-- scripts/register_new_matrix_user | 2 +- scripts/synapse_port_db | 2 +- setup.py | 2 +- synapse/__init__.py | 2 +- synapse/api/__init__.py | 2 +- synapse/api/constants.py | 2 +- synapse/api/errors.py | 2 +- synapse/api/filtering.py | 2 +- synapse/api/ratelimiting.py | 2 +- synapse/api/urls.py | 2 +- synapse/app/__init__.py | 2 +- synapse/app/homeserver.py | 2 +- synapse/app/synctl.py | 2 +- synapse/appservice/__init__.py | 2 +- synapse/appservice/api.py | 2 +- synapse/appservice/scheduler.py | 2 +- synapse/config/__init__.py | 2 +- synapse/config/__main__.py | 2 +- synapse/config/_base.py | 2 +- synapse/config/appservice.py | 2 +- synapse/config/captcha.py | 2 +- synapse/config/cas.py | 2 +- synapse/config/database.py | 2 +- synapse/config/homeserver.py | 2 +- synapse/config/key.py | 2 +- synapse/config/logger.py | 2 +- synapse/config/metrics.py | 2 +- synapse/config/password.py | 2 +- synapse/config/ratelimiting.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/server.py | 2 +- synapse/config/tls.py | 2 +- synapse/config/voip.py | 2 +- synapse/crypto/__init__.py | 2 +- synapse/crypto/context_factory.py | 2 +- synapse/crypto/event_signing.py | 2 +- synapse/crypto/keyclient.py | 2 +- synapse/crypto/keyring.py | 2 +- synapse/events/__init__.py | 2 +- synapse/events/builder.py | 2 +- synapse/events/snapshot.py | 2 +- synapse/events/utils.py | 2 +- synapse/events/validator.py | 2 +- synapse/federation/__init__.py | 2 +- synapse/federation/federation_base.py | 2 +- synapse/federation/federation_client.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/persistence.py | 2 +- synapse/federation/replication.py | 2 +- synapse/federation/transaction_queue.py | 2 +- synapse/federation/transport/__init__.py | 2 +- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/federation/units.py | 2 +- synapse/handlers/__init__.py | 2 +- synapse/handlers/account_data.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/directory.py | 2 +- synapse/handlers/events.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/profile.py | 2 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/typing.py | 2 +- synapse/http/__init__.py | 2 +- synapse/http/client.py | 2 +- synapse/http/endpoint.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/http/server.py | 2 +- synapse/http/servlet.py | 2 +- synapse/metrics/__init__.py | 2 +- synapse/metrics/metric.py | 2 +- synapse/metrics/resource.py | 2 +- synapse/push/__init__.py | 2 +- synapse/push/baserules.py | 2 +- synapse/push/httppusher.py | 2 +- synapse/push/push_rule_evaluator.py | 2 +- synapse/push/pusherpool.py | 2 +- synapse/push/rulekinds.py | 2 +- synapse/python_dependencies.py | 2 +- synapse/rest/__init__.py | 2 +- synapse/rest/client/__init__.py | 2 +- synapse/rest/client/v1/__init__.py | 2 +- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/base.py | 2 +- synapse/rest/client/v1/directory.py | 2 +- synapse/rest/client/v1/events.py | 2 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/client/v1/presence.py | 2 +- synapse/rest/client/v1/profile.py | 2 +- synapse/rest/client/v1/push_rule.py | 2 +- synapse/rest/client/v1/pusher.py | 2 +- synapse/rest/client/v1/register.py | 2 +- synapse/rest/client/v1/room.py | 2 +- synapse/rest/client/v1/transactions.py | 2 +- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/__init__.py | 2 +- synapse/rest/client/v2_alpha/_base.py | 2 +- synapse/rest/client/v2_alpha/account.py | 2 +- synapse/rest/client/v2_alpha/account_data.py | 2 +- synapse/rest/client/v2_alpha/auth.py | 2 +- synapse/rest/client/v2_alpha/filter.py | 2 +- synapse/rest/client/v2_alpha/keys.py | 2 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/client/v2_alpha/tags.py | 2 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 2 +- synapse/rest/key/__init__.py | 2 +- synapse/rest/key/v1/__init__.py | 2 +- synapse/rest/key/v1/server_key_resource.py | 2 +- synapse/rest/key/v2/__init__.py | 2 +- synapse/rest/key/v2/local_key_resource.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/rest/media/v0/content_repository.py | 2 +- synapse/rest/media/v1/__init__.py | 2 +- synapse/rest/media/v1/base_resource.py | 2 +- synapse/rest/media/v1/download_resource.py | 2 +- synapse/rest/media/v1/filepath.py | 2 +- synapse/rest/media/v1/identicon_resource.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/thumbnailer.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/server.py | 2 +- synapse/state.py | 2 +- synapse/storage/__init__.py | 2 +- synapse/storage/_base.py | 2 +- synapse/storage/account_data.py | 2 +- synapse/storage/appservice.py | 2 +- synapse/storage/background_updates.py | 2 +- synapse/storage/directory.py | 2 +- synapse/storage/end_to_end_keys.py | 2 +- synapse/storage/engines/__init__.py | 2 +- synapse/storage/engines/_base.py | 2 +- synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/event_federation.py | 2 +- synapse/storage/events.py | 2 +- synapse/storage/filtering.py | 2 +- synapse/storage/keys.py | 2 +- synapse/storage/media_repository.py | 2 +- synapse/storage/presence.py | 2 +- synapse/storage/profile.py | 2 +- synapse/storage/push_rule.py | 2 +- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 2 +- synapse/storage/rejections.py | 2 +- synapse/storage/room.py | 2 +- synapse/storage/roommember.py | 2 +- synapse/storage/schema/delta/11/v11.sql | 2 +- synapse/storage/schema/delta/12/v12.sql | 2 +- synapse/storage/schema/delta/13/v13.sql | 2 +- synapse/storage/schema/delta/14/upgrade_appservice_db.py | 2 +- synapse/storage/schema/delta/14/v14.sql | 2 +- synapse/storage/schema/delta/15/appservice_txns.sql | 2 +- synapse/storage/schema/delta/17/drop_indexes.sql | 2 +- synapse/storage/schema/delta/17/server_keys.sql | 2 +- synapse/storage/schema/delta/18/server_keys_bigger_ints.sql | 2 +- synapse/storage/schema/delta/19/event_index.sql | 2 +- synapse/storage/schema/delta/20/pushers.py | 2 +- synapse/storage/schema/delta/21/end_to_end_keys.sql | 2 +- synapse/storage/schema/delta/21/receipts.sql | 2 +- synapse/storage/schema/delta/22/receipts_index.sql | 2 +- synapse/storage/schema/delta/23/drop_state_index.sql | 2 +- synapse/storage/schema/delta/23/refresh_tokens.sql | 2 +- synapse/storage/schema/delta/24/stats_reporting.sql | 2 +- synapse/storage/schema/delta/25/00background_updates.sql | 2 +- synapse/storage/schema/delta/25/fts.py | 2 +- synapse/storage/schema/delta/25/guest_access.sql | 2 +- synapse/storage/schema/delta/25/history_visibility.sql | 2 +- synapse/storage/schema/delta/25/tags.sql | 2 +- synapse/storage/schema/delta/26/account_data.sql | 2 +- synapse/storage/schema/delta/27/account_data.sql | 2 +- synapse/storage/schema/delta/27/forgotten_memberships.sql | 2 +- synapse/storage/schema/delta/27/ts.py | 2 +- synapse/storage/schema/full_schemas/11/event_edges.sql | 2 +- synapse/storage/schema/full_schemas/11/event_signatures.sql | 2 +- synapse/storage/schema/full_schemas/11/im.sql | 2 +- synapse/storage/schema/full_schemas/11/keys.sql | 2 +- synapse/storage/schema/full_schemas/11/media_repository.sql | 2 +- synapse/storage/schema/full_schemas/11/presence.sql | 2 +- synapse/storage/schema/full_schemas/11/profiles.sql | 2 +- synapse/storage/schema/full_schemas/11/redactions.sql | 2 +- synapse/storage/schema/full_schemas/11/room_aliases.sql | 2 +- synapse/storage/schema/full_schemas/11/state.sql | 2 +- synapse/storage/schema/full_schemas/11/transactions.sql | 2 +- synapse/storage/schema/full_schemas/11/users.sql | 2 +- synapse/storage/schema/full_schemas/16/application_services.sql | 2 +- synapse/storage/schema/full_schemas/16/event_edges.sql | 2 +- synapse/storage/schema/full_schemas/16/event_signatures.sql | 2 +- synapse/storage/schema/full_schemas/16/im.sql | 2 +- synapse/storage/schema/full_schemas/16/keys.sql | 2 +- synapse/storage/schema/full_schemas/16/media_repository.sql | 2 +- synapse/storage/schema/full_schemas/16/presence.sql | 2 +- synapse/storage/schema/full_schemas/16/profiles.sql | 2 +- synapse/storage/schema/full_schemas/16/push.sql | 2 +- synapse/storage/schema/full_schemas/16/redactions.sql | 2 +- synapse/storage/schema/full_schemas/16/room_aliases.sql | 2 +- synapse/storage/schema/full_schemas/16/state.sql | 2 +- synapse/storage/schema/full_schemas/16/transactions.sql | 2 +- synapse/storage/schema/full_schemas/16/users.sql | 2 +- synapse/storage/schema/schema_version.sql | 2 +- synapse/storage/search.py | 2 +- synapse/storage/signatures.py | 2 +- synapse/storage/state.py | 2 +- synapse/storage/stream.py | 2 +- synapse/storage/tags.py | 2 +- synapse/storage/transactions.py | 2 +- synapse/storage/util/__init__.py | 2 +- synapse/storage/util/id_generators.py | 2 +- synapse/streams/__init__.py | 2 +- synapse/streams/config.py | 2 +- synapse/streams/events.py | 2 +- synapse/types.py | 2 +- synapse/util/__init__.py | 2 +- synapse/util/async.py | 2 +- synapse/util/caches/__init__.py | 2 +- synapse/util/caches/descriptors.py | 2 +- synapse/util/caches/dictionary_cache.py | 2 +- synapse/util/caches/expiringcache.py | 2 +- synapse/util/caches/lrucache.py | 2 +- synapse/util/caches/snapshot_cache.py | 2 +- synapse/util/debug.py | 2 +- synapse/util/distributor.py | 2 +- synapse/util/frozenutils.py | 2 +- synapse/util/jsonobject.py | 2 +- synapse/util/logcontext.py | 2 +- synapse/util/logutils.py | 2 +- synapse/util/ratelimitutils.py | 2 +- synapse/util/retryutils.py | 2 +- synapse/util/stringutils.py | 2 +- tests/__init__.py | 2 +- tests/api/test_filtering.py | 2 +- tests/appservice/__init__.py | 2 +- tests/appservice/test_appservice.py | 2 +- tests/appservice/test_scheduler.py | 2 +- tests/crypto/__init__.py | 2 +- tests/crypto/test_event_signing.py | 2 +- tests/events/test_utils.py | 2 +- tests/federation/test_federation.py | 2 +- tests/handlers/test_appservice.py | 2 +- tests/handlers/test_auth.py | 2 +- tests/handlers/test_directory.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_presence.py | 2 +- tests/handlers/test_presencelike.py | 2 +- tests/handlers/test_profile.py | 2 +- tests/handlers/test_room.py | 2 +- tests/handlers/test_typing.py | 2 +- tests/metrics/test_metric.py | 2 +- tests/rest/__init__.py | 2 +- tests/rest/client/__init__.py | 2 +- tests/rest/client/v1/__init__.py | 2 +- tests/rest/client/v1/test_events.py | 2 +- tests/rest/client/v1/test_presence.py | 2 +- tests/rest/client/v1/test_profile.py | 2 +- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v1/test_typing.py | 2 +- tests/rest/client/v1/utils.py | 2 +- tests/rest/client/v2_alpha/__init__.py | 2 +- tests/rest/client/v2_alpha/test_filter.py | 2 +- tests/storage/event_injector.py | 2 +- tests/storage/test__base.py | 2 +- tests/storage/test_appservice.py | 2 +- tests/storage/test_base.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_events.py | 2 +- tests/storage/test_presence.py | 2 +- tests/storage/test_profile.py | 2 +- tests/storage/test_redaction.py | 2 +- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_stream.py | 2 +- tests/test_distributor.py | 2 +- tests/test_state.py | 2 +- tests/test_test_utils.py | 2 +- tests/test_types.py | 2 +- tests/unittest.py | 2 +- tests/util/__init__.py | 2 +- tests/util/test_dict_cache.py | 2 +- tests/util/test_lrucache.py | 2 +- tests/util/test_snapshot_cache.py | 2 +- tests/utils.py | 2 +- 295 files changed, 297 insertions(+), 297 deletions(-) (limited to 'tests') diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index d9c6ec6a70..8bb03ce66a 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2014 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 869f782ec1..4186897316 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py index 95d87a1fda..44afe81008 100644 --- a/contrib/experiments/cursesio.py +++ b/contrib/experiments/cursesio.py @@ -1,4 +1,4 @@ -# Copyright 2014 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index fedf786cec..85c9c11984 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py index b2acadcf5e..afd1d446b4 100644 --- a/contrib/graph/graph.py +++ b/contrib/graph/graph.py @@ -1,4 +1,4 @@ -# Copyright 2014 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py index d0d2cfe7c0..1ccad65728 100644 --- a/contrib/graph/graph2.py +++ b/contrib/graph/graph2.py @@ -1,4 +1,4 @@ -# Copyright 2014 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/scripts-dev/copyrighter-sql.pl b/scripts-dev/copyrighter-sql.pl index 890e51e587..13e630fc11 100755 --- a/scripts-dev/copyrighter-sql.pl +++ b/scripts-dev/copyrighter-sql.pl @@ -1,5 +1,5 @@ #!/usr/bin/perl -pi -# Copyright 2015 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. $copyright = < Date: Mon, 11 Jan 2016 15:29:57 +0000 Subject: Introduce a Requester object This tracks data about the entity which made the request. This is instead of passing around a tuple, which requires call-site modifications every time a new piece of optional context is passed around. I tried to introduce a User object. I gave up. --- synapse/api/auth.py | 8 +- synapse/rest/client/v1/admin.py | 5 +- synapse/rest/client/v1/directory.py | 8 +- synapse/rest/client/v1/events.py | 19 +++-- synapse/rest/client/v1/initial_sync.py | 4 +- synapse/rest/client/v1/presence.py | 16 ++-- synapse/rest/client/v1/profile.py | 8 +- synapse/rest/client/v1/push_rule.py | 13 +-- synapse/rest/client/v1/pusher.py | 5 +- synapse/rest/client/v1/room.py | 116 ++++++++++++++++----------- synapse/rest/client/v1/voip.py | 4 +- synapse/rest/client/v2_alpha/account.py | 20 ++--- synapse/rest/client/v2_alpha/account_data.py | 8 +- synapse/rest/client/v2_alpha/filter.py | 8 +- synapse/rest/client/v2_alpha/keys.py | 16 ++-- synapse/rest/client/v2_alpha/receipts.py | 4 +- synapse/rest/client/v2_alpha/sync.py | 11 +-- synapse/rest/client/v2_alpha/tags.py | 12 +-- synapse/rest/media/v0/content_repository.py | 6 +- synapse/rest/media/v1/upload_resource.py | 4 +- synapse/types.py | 3 + tests/api/test_auth.py | 12 +-- tests/rest/client/v1/test_presence.py | 5 +- tests/rest/client/v1/test_profile.py | 7 +- 24 files changed, 178 insertions(+), 144 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b86c6c8399..876869bb74 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -22,7 +22,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import RoomID, UserID, EventID +from synapse.types import Requester, RoomID, UserID, EventID from synapse.util.logutils import log_function from unpaddedbase64 import decode_base64 @@ -534,7 +534,9 @@ class Auth(object): request.authenticated_entity = user_id - defer.returnValue((UserID.from_string(user_id), "", False)) + defer.returnValue( + Requester(UserID.from_string(user_id), "", False) + ) return except KeyError: pass # normal users won't have the user_id query parameter set. @@ -564,7 +566,7 @@ class Auth(object): request.authenticated_entity = user.to_string() - defer.returnValue((user, token_id, is_guest,)) + defer.returnValue(Requester(user, token_id, is_guest)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 4d724dce72..e2f5eb7b29 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,8 +31,9 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _, _ = yield self.auth.get_user_by_req(request) - is_admin = yield self.auth.is_server_admin(auth_user) + requester = yield self.auth.get_user_by_req(request) + auth_user = requester.user + is_admin = yield self.auth.is_server_admin(requester.user) if not is_admin and target_user != auth_user: raise AuthError(403, "You are not a server admin") diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 7eef6bf5dc..74ec1e50e0 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -69,9 +69,9 @@ class ClientDirectoryServer(ClientV1RestServlet): try: # try to auth as a user - user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) try: - user_id = user.to_string() + user_id = requester.user.to_string() yield dir_handler.create_association( user_id, room_alias, room_id, servers ) @@ -116,8 +116,8 @@ class ClientDirectoryServer(ClientV1RestServlet): # fallback to default user behaviour if they aren't an AS pass - user, _, _ = yield self.auth.get_user_by_req(request) - + requester = yield self.auth.get_user_by_req(request) + user = requester.user is_admin = yield self.auth.is_server_admin(user) if not is_admin: raise AuthError(403, "You need to be a server admin") diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 631f2ca052..e89118b37d 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,10 +34,11 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _, is_guest = yield self.auth.get_user_by_req( + requester = yield self.auth.get_user_by_req( request, - allow_guest=True + allow_guest=True, ) + is_guest = requester.is_guest room_id = None if is_guest: if "room_id" not in request.args: @@ -56,9 +57,13 @@ class EventStreamRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args chunk = yield handler.get_stream( - auth_user.to_string(), pagin_config, timeout=timeout, - as_client_event=as_client_event, affect_presence=(not is_guest), - room_id=room_id, is_guest=is_guest + requester.user.to_string(), + pagin_config, + timeout=timeout, + as_client_event=as_client_event, + affect_presence=(not is_guest), + room_id=room_id, + is_guest=is_guest, ) except: logger.exception("Event stream failed") @@ -80,9 +85,9 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler - event = yield handler.get_event(auth_user, event_id) + event = yield handler.get_event(requester.user, event_id) time_now = self.clock.time_msec() if event: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 541319c351..ad161bdbab 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,13 +25,13 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler include_archived = request.args.get("archived", None) == ["true"] content = yield handler.snapshot_all_rooms( - user_id=user.to_string(), + user_id=requester.user.to_string(), pagin_config=pagination_config, as_client_event=as_client_event, include_archived=include_archived, diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 855385ec16..a6f8754e32 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,17 +32,17 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( - target_user=user, auth_user=auth_user) + target_user=user, auth_user=requester.user) defer.returnValue((200, state)) @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -64,7 +64,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): raise SynapseError(400, "Unable to parse state") yield self.handlers.presence_handler.set_state( - target_user=user, auth_user=auth_user, state=state) + target_user=user, auth_user=requester.user, state=state) defer.returnValue((200, {})) @@ -77,13 +77,13 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): raise SynapseError(400, "User not hosted on this Home Server") - if auth_user != user: + if requester.user != user: raise SynapseError(400, "Cannot get another user's presence list") presence = yield self.handlers.presence_handler.get_presence_list( @@ -97,13 +97,13 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): raise SynapseError(400, "User not hosted on this Home Server") - if auth_user != user: + if requester.user != user: raise SynapseError( 400, "Cannot modify another user's presence list") diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index d4bc9e076c..b15defdd07 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) try: @@ -47,7 +47,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_displayname( - user, auth_user, new_name) + user, requester.user, new_name) defer.returnValue((200, {})) @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -80,7 +80,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): defer.returnValue((400, "Unable to parse name")) yield self.handlers.profile_handler.set_avatar_url( - user, auth_user, new_name) + user, requester.user, new_name) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 2aab28ae7b..c0a21c0c12 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet): except InvalidRuleException as e: raise SynapseError(400, e.message) - user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") @@ -51,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet): content = _parse_json(request) if 'attr' in spec: - self.set_rule_attr(user.to_string(), spec, content) + self.set_rule_attr(requester.user, spec, content) defer.returnValue((200, {})) try: @@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet): try: yield self.hs.get_datastore().add_push_rule( - user_name=user.to_string(), + user_name=requester.user.to_string(), rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, conditions=conditions, @@ -92,13 +92,13 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_DELETE(self, request): spec = _rule_spec_from_path(request.postpath) - user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: yield self.hs.get_datastore().delete_push_rule( - user.to_string(), namespaced_rule_id + requester.user.to_string(), namespaced_rule_id ) defer.returnValue((200, {})) except StoreError as e: @@ -109,7 +109,8 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) + user = requester.user # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 81a8786aeb..b162b210bc 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -30,7 +30,8 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, token_id, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) + user = requester.user content = _parse_json(request) @@ -71,7 +72,7 @@ class PusherRestServlet(ClientV1RestServlet): try: yield pusher_pool.add_pusher( user_name=user.to_string(), - access_token=token_id, + access_token=requester.access_token_id, profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 926f77d1c3..7496b26735 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -61,10 +61,14 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) - info = yield self.make_room(room_config, auth_user, None) + info = yield self.make_room( + room_config, + requester.user, + None, + ) room_config.update(info) defer.returnValue((200, info)) @@ -124,15 +128,15 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( - user_id=user.to_string(), + user_id=requester.user.to_string(), room_id=room_id, event_type=event_type, state_key=state_key, - is_guest=is_guest, + is_guest=requester.is_guest, ) if not data: @@ -143,7 +147,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - user, token_id, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -151,7 +155,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): "type": event_type, "content": content, "room_id": room_id, - "sender": user.to_string(), + "sender": requester.user.to_string(), } if state_key is not None: @@ -159,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler yield msg_handler.create_and_send_event( - event_dict, token_id=token_id, txn_id=txn_id, + event_dict, token_id=requester.access_token_id, txn_id=txn_id, ) defer.returnValue((200, {})) @@ -175,7 +179,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): - user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -184,9 +188,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "type": event_type, "content": content, "room_id": room_id, - "sender": user.to_string(), + "sender": requester.user.to_string(), }, - token_id=token_id, + token_id=requester.access_token_id, txn_id=txn_id, ) @@ -220,9 +224,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - user, token_id, is_guest = yield self.auth.get_user_by_req( + requester = yield self.auth.get_user_by_req( request, - allow_guest=True + allow_guest=True, ) # the identifier could be a room alias or a room id. Try one then the @@ -241,24 +245,27 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if is_room_alias: handler = self.handlers.room_member_handler - ret_dict = yield handler.join_room_alias(user, identifier) + ret_dict = yield handler.join_room_alias( + requester.user, + identifier, + ) defer.returnValue((200, ret_dict)) else: # room id msg_handler = self.handlers.message_handler content = {"membership": Membership.JOIN} - if is_guest: + if requester.is_guest: content["kind"] = "guest" yield msg_handler.create_and_send_event( { "type": EventTypes.Member, "content": content, "room_id": identifier.to_string(), - "sender": user.to_string(), - "state_key": user.to_string(), + "sender": requester.user.to_string(), + "state_key": requester.user.to_string(), }, - token_id=token_id, + token_id=requester.access_token_id, txn_id=txn_id, - is_guest=is_guest, + is_guest=requester.is_guest, ) defer.returnValue((200, {"room_id": identifier.to_string()})) @@ -296,11 +303,11 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler events = yield handler.get_state_events( room_id=room_id, - user_id=user.to_string(), + user_id=requester.user.to_string(), ) chunk = [] @@ -315,7 +322,8 @@ class RoomMemberListRestServlet(ClientV1RestServlet): try: presence_handler = self.handlers.presence_handler presence_state = yield presence_handler.get_state( - target_user=target_user, auth_user=user + target_user=target_user, + auth_user=requester.user, ) event["content"].update(presence_state) except: @@ -332,7 +340,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -340,8 +348,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet): handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, - user_id=user.to_string(), - is_guest=is_guest, + user_id=requester.user.to_string(), + is_guest=requester.is_guest, pagin_config=pagination_config, as_client_event=as_client_event ) @@ -355,13 +363,13 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( room_id=room_id, - user_id=user.to_string(), - is_guest=is_guest, + user_id=requester.user.to_string(), + is_guest=requester.is_guest, ) defer.returnValue((200, events)) @@ -372,13 +380,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, - user_id=user.to_string(), + user_id=requester.user.to_string(), pagin_config=pagination_config, - is_guest=is_guest, + is_guest=requester.is_guest, ) defer.returnValue((200, content)) @@ -394,12 +402,16 @@ class RoomEventContext(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): - user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) limit = int(request.args.get("limit", [10])[0]) results = yield self.handlers.room_context_handler.get_event_context( - user, room_id, event_id, limit, is_guest + requester.user, + room_id, + event_id, + limit, + requester.is_guest, ) time_now = self.clock.time_msec() @@ -429,14 +441,18 @@ class RoomMembershipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - user, token_id, is_guest = yield self.auth.get_user_by_req( + requester = yield self.auth.get_user_by_req( request, - allow_guest=True + allow_guest=True, ) + user = requester.user effective_membership_action = membership_action - if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}: + if requester.is_guest and membership_action not in { + Membership.JOIN, + Membership.LEAVE + }: raise AuthError(403, "Guest access not allowed") content = _parse_json(request) @@ -451,7 +467,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content["medium"], content["address"], content["id_server"], - token_id, + requester.access_token_id, txn_id ) defer.returnValue((200, {})) @@ -473,7 +489,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler content = {"membership": unicode(effective_membership_action)} - if is_guest: + if requester.is_guest: content["kind"] = "guest" yield msg_handler.create_and_send_event( @@ -484,9 +500,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): "sender": user.to_string(), "state_key": state_key, }, - token_id=token_id, + token_id=requester.access_token_id, txn_id=txn_id, - is_guest=is_guest, + is_guest=requester.is_guest, ) if membership_action == "forget": @@ -524,7 +540,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): - user, token_id, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -533,10 +549,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "type": EventTypes.Redaction, "content": content, "room_id": room_id, - "sender": user.to_string(), + "sender": requester.user.to_string(), "redacts": event_id, }, - token_id=token_id, + token_id=requester.access_token_id, txn_id=txn_id, ) @@ -564,7 +580,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) @@ -576,14 +592,14 @@ class RoomTypingRestServlet(ClientV1RestServlet): if content["typing"]: yield typing_handler.started_typing( target_user=target_user, - auth_user=auth_user, + auth_user=requester.user, room_id=room_id, timeout=content.get("timeout", 30000), ) else: yield typing_handler.stopped_typing( target_user=target_user, - auth_user=auth_user, + auth_user=requester.user, room_id=room_id, ) @@ -597,12 +613,16 @@ class SearchRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) content = _parse_json(request) batch = request.args.get("next_batch", [None])[0] - results = yield self.handlers.search_handler.search(auth_user, content, batch) + results = yield self.handlers.search_handler.search( + requester.user, + content, + batch, + ) defer.returnValue((200, results)) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 860cb0a642..ec4cf8db79 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret @@ -37,7 +37,7 @@ class VoipRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 - username = "%d:%s" % (expiry, auth_user.to_string()) + username = "%d:%s" % (expiry, requester.user.to_string()) mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1) # We need to use standard padded base64 encoding here diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ddb6f041cd..fa56249a69 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -55,10 +55,11 @@ class PasswordRestServlet(RestServlet): if LoginType.PASSWORD in result: # if using password, they should also be logged in - auth_user, _, _ = yield self.auth.get_user_by_req(request) - if auth_user.to_string() != result[LoginType.PASSWORD]: + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + if requester_user_id.to_string() != result[LoginType.PASSWORD]: raise LoginError(400, "", Codes.UNKNOWN) - user_id = auth_user.to_string() + user_id = requester_user_id elif LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] if 'medium' not in threepid or 'address' not in threepid: @@ -102,10 +103,10 @@ class ThreepidRestServlet(RestServlet): def on_GET(self, request): yield run_on_reactor() - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) threepids = yield self.hs.get_datastore().user_get_threepids( - auth_user.to_string() + requester.user.to_string() ) defer.returnValue((200, {'threepids': threepids})) @@ -120,7 +121,8 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) threePidCreds = body['threePidCreds'] - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) @@ -135,7 +137,7 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(500, "Invalid response from ID Server") yield self.auth_handler.add_threepid( - auth_user.to_string(), + user_id, threepid['medium'], threepid['address'], threepid['validated_at'], @@ -144,10 +146,10 @@ class ThreepidRestServlet(RestServlet): if 'bind' in body and body['bind']: logger.debug( "Binding emails %s to %s", - threepid, auth_user.to_string() + threepid, user_id ) yield self.identity_handler.bind_threepid( - threePidCreds, auth_user.to_string() + threePidCreds, user_id ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 629b04fe7a..985efe2a62 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -43,8 +43,8 @@ class AccountDataServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id, account_data_type): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - if user_id != auth_user.to_string(): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") try: @@ -82,8 +82,8 @@ class RoomAccountDataServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id, room_id, account_data_type): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - if user_id != auth_user.to_string(): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") try: diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 2af7bfaf99..7695bebc28 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -40,9 +40,9 @@ class GetFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) - if target_user != auth_user: + if target_user != requester.user: raise AuthError(403, "Cannot get filters for other users") if not self.hs.is_mine(target_user): @@ -76,9 +76,9 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) - if target_user != auth_user: + if target_user != requester.user: raise AuthError(403, "Cannot create filters for other users") if not self.hs.is_mine(target_user): diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 24c3554831..f989b08614 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -64,8 +64,8 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - user_id = auth_user.to_string() + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. try: @@ -78,8 +78,8 @@ class KeyUploadServlet(RestServlet): device_keys = body.get("device_keys", None) if device_keys: logger.info( - "Updating device_keys for device %r for user %r at %d", - device_id, auth_user, time_now + "Updating device_keys for device %r for user %s at %d", + device_id, user_id, time_now ) # TODO: Sign the JSON with the server key yield self.store.set_e2e_device_keys( @@ -109,8 +109,8 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - user_id = auth_user.to_string() + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue((200, {"one_time_key_counts": result})) @@ -182,8 +182,8 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - auth_user_id = auth_user.to_string() + requester = yield self.auth.get_user_by_req(request) + auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] result = yield self.handle_request( diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 43c23d6090..eb4b369a3d 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): - user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") @@ -48,7 +48,7 @@ class ReceiptRestServlet(RestServlet): yield self.receipts_handler.received_client_receipt( room_id, receipt_type, - user_id=user.to_string(), + user_id=requester.user.to_string(), event_id=event_id ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c05e7d50c8..3867547ade 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -85,9 +85,10 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, token_id, is_guest = yield self.auth.get_user_by_req( + requester = yield self.auth.get_user_by_req( request, allow_guest=True ) + user = requester.user timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") @@ -123,7 +124,7 @@ class SyncRestServlet(RestServlet): sync_config = SyncConfig( user=user, filter=filter, - is_guest=is_guest, + is_guest=requester.is_guest, ) if since is not None: @@ -146,15 +147,15 @@ class SyncRestServlet(RestServlet): time_now = self.clock.time_msec() joined = self.encode_joined( - sync_result.joined, filter, time_now, token_id + sync_result.joined, filter, time_now, requester.access_token_id ) invited = self.encode_invited( - sync_result.invited, filter, time_now, token_id + sync_result.invited, filter, time_now, requester.access_token_id ) archived = self.encode_archived( - sync_result.archived, filter, time_now, token_id + sync_result.archived, filter, time_now, requester.access_token_id ) response_content = { diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 1bfc36ab2b..42f2203f3d 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -42,8 +42,8 @@ class TagListServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, room_id): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - if user_id != auth_user.to_string(): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") tags = yield self.store.get_tags_for_room(user_id, room_id) @@ -68,8 +68,8 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id, room_id, tag): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - if user_id != auth_user.to_string(): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") try: @@ -88,8 +88,8 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, user_id, room_id, tag): - auth_user, _, _ = yield self.auth.get_user_by_req(request) - if user_id != auth_user.to_string(): + requester = yield self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index dd7a1b2b31..dcf3eaee1f 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,11 +66,11 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( - auth_user.to_string() + requester.user.to_string() ).replace('=', '') # use a random string for the main portion @@ -94,7 +94,7 @@ class ContentRepoResource(resource.Resource): file_name = prefix + main_part + suffix file_path = os.path.join(self.directory, file_name) logger.info("User %s is uploading a file to path %s", - auth_user.to_string(), + request.user.user_id.to_string(), file_path) # keep trying to make a non-clashing file, with a sensible max attempts diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index c1e895ee81..9c7ad4ae85 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") @@ -110,7 +110,7 @@ class UploadResource(BaseMediaResource): content_uri = yield self.create_content( media_type, upload_name, request.content.read(), - content_length, auth_user + content_length, requester.user ) respond_with_json( diff --git a/synapse/types.py b/synapse/types.py index 1ec7b3e103..2095837ba6 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,6 +18,9 @@ from synapse.api.errors import SynapseError from collections import namedtuple +Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) + + class DomainSpecificString( namedtuple("DomainSpecificString", ("localpart", "domain")) ): diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 5ff4c8a873..474c5c418f 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -51,8 +51,8 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _, _) = yield self.auth.get_user_by_req(request) - self.assertEquals(user.to_string(), self.test_user) + requester = yield self.auth.get_user_by_req(request) + self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) @@ -86,8 +86,8 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _, _) = yield self.auth.get_user_by_req(request) - self.assertEquals(user.to_string(), self.test_user) + requester = yield self.auth.get_user_by_req(request) + self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) @@ -121,8 +121,8 @@ class AuthTestCase(unittest.TestCase): request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _, _) = yield self.auth.get_user_by_req(request) - self.assertEquals(user.to_string(), masquerading_user_id) + requester = yield self.auth.get_user_by_req(request) + self.assertEquals(requester.user.to_string(), masquerading_user_id) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index d782eadb6a..90b911f879 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -14,7 +14,6 @@ # limitations under the License. """Tests REST events for /presence paths.""" - from tests import unittest from twisted.internet import defer @@ -26,7 +25,7 @@ from synapse.api.constants import PresenceState from synapse.handlers.presence import PresenceHandler from synapse.rest.client.v1 import presence from synapse.rest.client.v1 import events -from synapse.types import UserID +from synapse.types import Requester, UserID from synapse.util.async import run_on_reactor from collections import namedtuple @@ -301,7 +300,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): hs.get_clock().time_msec.return_value = 1000000 def _get_user_by_req(req=None, allow_guest=False): - return (UserID.from_string(myid), "", False) + return Requester(UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 77b7b06c10..c1a3f52043 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,16 +14,15 @@ # limitations under the License. """Tests REST events for /profile paths.""" - from tests import unittest from twisted.internet import defer -from mock import Mock, NonCallableMock +from mock import Mock from ....utils import MockHttpResource, setup_test_homeserver from synapse.api.errors import SynapseError, AuthError -from synapse.types import UserID +from synapse.types import Requester, UserID from synapse.rest.client.v1 import profile @@ -53,7 +52,7 @@ class ProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return (UserID.from_string(myid), "", False) + return Requester(UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req -- cgit 1.5.1 From c0a279e808435d286ae254f51253d2adb3ee7858 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 13 Jan 2016 11:15:20 +0000 Subject: Delete the table objects from TransactionStore --- synapse/storage/transactions.py | 68 ++++++----------------------------------- tests/handlers/test_presence.py | 1 - tests/handlers/test_typing.py | 1 - 3 files changed, 10 insertions(+), 60 deletions(-) (limited to 'tests') diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index b40a070b69..4475c451c1 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -16,8 +16,6 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached -from collections import namedtuple - from canonicaljson import encode_canonical_json import logging @@ -50,12 +48,15 @@ class TransactionStore(SQLBaseStore): def _get_received_txn_response(self, txn, transaction_id, origin): result = self._simple_select_one_txn( txn, - table=ReceivedTransactionsTable.table_name, + table="received_transactions", keyvalues={ "transaction_id": transaction_id, "origin": origin, }, - retcols=ReceivedTransactionsTable.fields, + retcols=( + "transaction_id", "origin", "ts", "response_code", "response_json", + "has_been_referenced", + ), allow_none=True, ) @@ -79,7 +80,7 @@ class TransactionStore(SQLBaseStore): """ return self._simple_insert( - table=ReceivedTransactionsTable.table_name, + table="received_transactions", values={ "transaction_id": transaction_id, "origin": origin, @@ -136,7 +137,7 @@ class TransactionStore(SQLBaseStore): self._simple_insert_txn( txn, - table=SentTransactions.table_name, + table="sent_transactions", values={ "id": next_id, "transaction_id": transaction_id, @@ -171,7 +172,7 @@ class TransactionStore(SQLBaseStore): code, response_json): self._simple_update_one_txn( txn, - table=SentTransactions.table_name, + table="sent_transactions", keyvalues={ "transaction_id": transaction_id, "destination": destination, @@ -229,11 +230,11 @@ class TransactionStore(SQLBaseStore): def _get_destination_retry_timings(self, txn, destination): result = self._simple_select_one_txn( txn, - table=DestinationsTable.table_name, + table="destinations", keyvalues={ "destination": destination, }, - retcols=DestinationsTable.fields, + retcols=("destination", "retry_last_ts", "retry_interval"), allow_none=True, ) @@ -304,52 +305,3 @@ class TransactionStore(SQLBaseStore): txn.execute(query, (self._clock.time_msec(),)) return self.cursor_to_dict(txn) - - -class ReceivedTransactionsTable(object): - table_name = "received_transactions" - - fields = [ - "transaction_id", - "origin", - "ts", - "response_code", - "response_json", - "has_been_referenced", - ] - - -class SentTransactions(object): - table_name = "sent_transactions" - - fields = [ - "id", - "transaction_id", - "destination", - "ts", - "response_code", - "response_json", - ] - - EntryType = namedtuple("SentTransactionsEntry", fields) - - -class TransactionsToPduTable(object): - table_name = "transaction_id_to_pdu" - - fields = [ - "transaction_id", - "destination", - "pdu_id", - "pdu_origin", - ] - - -class DestinationsTable(object): - table_name = "destinations" - - fields = [ - "destination", - "retry_last_ts", - "retry_interval", - ] diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 15000aae0c..447a22b5fc 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -28,7 +28,6 @@ from synapse.api.constants import PresenceState from synapse.api.errors import SynapseError from synapse.handlers.presence import PresenceHandler, UserPresenceCache from synapse.streams.config import SourcePaginationConfig -from synapse.storage.transactions import DestinationsTable from synapse.types import UserID OFFLINE = PresenceState.OFFLINE diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 124bc10e0f..763c04d667 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -27,7 +27,6 @@ from ..utils import ( from synapse.api.errors import AuthError from synapse.handlers.typing import TypingNotificationHandler -from synapse.storage.transactions import DestinationsTable from synapse.types import UserID -- cgit 1.5.1 From 2680043bc6a64053b93b9bab144aeb5f45007976 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 14 Jan 2016 14:34:01 +0000 Subject: Require ID and as_token be unique for ASs Defaults ID to as_token if not specified. This will change when IDs are fully supported. --- synapse/storage/appservice.py | 26 +++++++++- tests/appservice/test_appservice.py | 1 + tests/storage/test_appservice.py | 101 ++++++++++++++++++++++++++++++------ 3 files changed, 111 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index f4bc457eca..b5aa55c0a3 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -20,6 +20,7 @@ from twisted.internet import defer from synapse.api.constants import Membership from synapse.appservice import ApplicationService, AppServiceTransaction +from synapse.config._base import ConfigError from synapse.storage.roommember import RoomsForUser from synapse.types import UserID from ._base import SQLBaseStore @@ -145,6 +146,7 @@ class ApplicationServiceStore(SQLBaseStore): def _load_appservice(self, as_info): required_string_fields = [ + # TODO: Add id here when it's stable to release "url", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: @@ -186,7 +188,7 @@ class ApplicationServiceStore(SQLBaseStore): namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], sender=user_id, - id=as_info["as_token"] # the token is the only unique thing here + id=as_info["id"] if "id" in as_info else as_info["as_token"], ) def _populate_appservice_cache(self, config_files): @@ -197,10 +199,32 @@ class ApplicationServiceStore(SQLBaseStore): ) return + # Dicts of value -> filename + seen_as_tokens = {} + seen_ids = {} + for config_file in config_files: try: with open(config_file, 'r') as f: appservice = self._load_appservice(yaml.load(f)) + if appservice.id in seen_ids: + raise ConfigError( + "Cannot reuse ID across application services: " + "%s (files: %s, %s)" % ( + appservice.id, config_file, seen_ids[appservice.id], + ) + ) + seen_ids[appservice.id] = config_file + if appservice.token in seen_as_tokens: + raise ConfigError( + "Cannot reuse as_token across application services: " + "%s (files: %s, %s)" % ( + appservice.token, + config_file, + seen_as_tokens[appservice.token], + ) + ) + seen_as_tokens[appservice.token] = config_file logger.info("Loaded application service: %s", appservice) self.services_cache.append(appservice) except Exception as e: diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 191c420c4d..ef48bbc296 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -29,6 +29,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def setUp(self): self.service = ApplicationService( + id="unique_identifier", url="some_url", token="some_token", namespaces={ diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a5a464640f..5abecdf6e0 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -12,12 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import tempfile +from synapse.config._base import ConfigError from tests import unittest from twisted.internet import defer from tests.utils import setup_test_homeserver from synapse.appservice import ApplicationService, ApplicationServiceState -from synapse.server import HomeServer from synapse.storage.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) @@ -26,7 +27,6 @@ import json import os import yaml from mock import Mock -from tests.utils import SQLiteMemoryDbPool, MockClock class ApplicationServiceStoreTestCase(unittest.TestCase): @@ -41,9 +41,16 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.as_token = "token1" self.as_url = "some_url" - self._add_appservice(self.as_token, self.as_url, "some_hs_token", "bob") - self._add_appservice("token2", "some_url", "some_hs_token", "bob") - self._add_appservice("token3", "some_url", "some_hs_token", "bob") + self.as_id = "as1" + self._add_appservice( + self.as_token, + self.as_id, + self.as_url, + "some_hs_token", + "bob" + ) + self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") + self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts self.store = ApplicationServiceStore(hs) @@ -55,9 +62,9 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): except: pass - def _add_appservice(self, as_token, url, hs_token, sender): + def _add_appservice(self, as_token, id, url, hs_token, sender): as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token, - sender_localpart=sender, namespaces={}) + id=id, sender_localpart=sender, namespaces={}) # use the token as the filename with open(as_token, 'w') as outfile: outfile.write(yaml.dump(as_yaml)) @@ -74,6 +81,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.as_token ) self.assertEquals(stored_service.token, self.as_token) + self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) self.assertEquals( stored_service.namespaces[ApplicationService.NS_ALIASES], @@ -110,34 +118,34 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): { "token": "token1", "url": "https://matrix-as.org", - "id": "token1" + "id": "id_1" }, { "token": "alpha_tok", "url": "https://alpha.com", - "id": "alpha_tok" + "id": "id_alpha" }, { "token": "beta_tok", "url": "https://beta.com", - "id": "beta_tok" + "id": "id_beta" }, { - "token": "delta_tok", - "url": "https://delta.com", - "id": "delta_tok" + "token": "gamma_tok", + "url": "https://gamma.com", + "id": "id_gamma" }, ] for s in self.as_list: - yield self._add_service(s["url"], s["token"]) + yield self._add_service(s["url"], s["token"], s["id"]) self.as_yaml_files = [] self.store = TestTransactionStore(hs) - def _add_service(self, url, as_token): + def _add_service(self, url, as_token, id): as_yaml = dict(url=url, as_token=as_token, hs_token="something", - sender_localpart="a_sender", namespaces={}) + id=id, sender_localpart="a_sender", namespaces={}) # use the token as the filename with open(as_token, 'w') as outfile: outfile.write(yaml.dump(as_yaml)) @@ -405,3 +413,64 @@ class TestTransactionStore(ApplicationServiceTransactionStore, def __init__(self, hs): super(TestTransactionStore, self).__init__(hs) + + +class ApplicationServiceStoreConfigTestCase(unittest.TestCase): + + def _write_config(self, suffix, **kwargs): + vals = { + "id": "id" + suffix, + "url": "url" + suffix, + "as_token": "as_token" + suffix, + "hs_token": "hs_token" + suffix, + "sender_localpart": "sender_localpart" + suffix, + "namespaces": {}, + } + vals.update(kwargs) + + _, path = tempfile.mkstemp(prefix="as_config") + with open(path, "w") as f: + f.write(yaml.dump(vals)) + return path + + @defer.inlineCallbacks + def test_unique_works(self): + f1 = self._write_config(suffix="1") + f2 = self._write_config(suffix="2") + + config = Mock(app_service_config_files=[f1, f2]) + hs = yield setup_test_homeserver(config=config) + + ApplicationServiceStore(hs) + + @defer.inlineCallbacks + def test_duplicate_ids(self): + f1 = self._write_config(id="id", suffix="1") + f2 = self._write_config(id="id", suffix="2") + + config = Mock(app_service_config_files=[f1, f2]) + hs = yield setup_test_homeserver(config=config) + + with self.assertRaises(ConfigError) as cm: + ApplicationServiceStore(hs) + + e = cm.exception + self.assertIn(f1, e.message) + self.assertIn(f2, e.message) + self.assertIn("id", e.message) + + @defer.inlineCallbacks + def test_duplicate_as_tokens(self): + f1 = self._write_config(as_token="as_token", suffix="1") + f2 = self._write_config(as_token="as_token", suffix="2") + + config = Mock(app_service_config_files=[f1, f2]) + hs = yield setup_test_homeserver(config=config) + + with self.assertRaises(ConfigError) as cm: + ApplicationServiceStore(hs) + + e = cm.exception + self.assertIn(f1, e.message) + self.assertIn(f2, e.message) + self.assertIn("as_token", e.message) -- cgit 1.5.1 From ac5a4477adc772e4416c868e8b16ae41a2c0c4ef Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Fri, 15 Jan 2016 16:27:26 +0000 Subject: Require unbanning before other membership changes --- synapse/api/errors.py | 1 + synapse/handlers/federation.py | 4 +-- synapse/handlers/message.py | 57 +++++++++++++++++++++++++++++++++--------- synapse/handlers/room.py | 55 ++++++++++++++++++++++++++++++++++++++-- synapse/rest/client/v1/room.py | 51 +++++++++---------------------------- tests/handlers/test_room.py | 6 ++--- 6 files changed, 116 insertions(+), 58 deletions(-) (limited to 'tests') diff --git a/synapse/api/errors.py b/synapse/api/errors.py index ce0fc53668..b106fbed6d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -29,6 +29,7 @@ class Codes(object): USER_IN_USE = "M_USER_IN_USE" ROOM_IN_USE = "M_ROOM_IN_USE" BAD_PAGINATION = "M_BAD_PAGINATION" + BAD_STATE = "M_BAD_STATE" UNKNOWN = "M_UNKNOWN" NOT_FOUND = "M_NOT_FOUND" MISSING_TOKEN = "M_MISSING_TOKEN" diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2f6359c768..26402ea9cd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1693,7 +1693,7 @@ class FederationHandler(BaseHandler): self.auth.check(event, context.current_state) yield self._validate_keyserver(event, auth_events=context.current_state) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.change_membership(event, context) + yield member_handler.send_membership_event(event, context) else: destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)]) yield self.replication_layer.forward_third_party_invite( @@ -1722,7 +1722,7 @@ class FederationHandler(BaseHandler): # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.change_membership(event, context) + yield member_handler.send_membership_event(event, context) @defer.inlineCallbacks def add_display_name_to_third_party_invite(self, event_dict, event, context): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5805190ce8..4c7bf2bef3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -174,30 +174,25 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True, - token_id=None, txn_id=None, is_guest=False): - """ Given a dict from a client, create and handle a new event. + def create_event(self, event_dict, token_id=None, txn_id=None): + """ + Given a dict from a client, create a new event. Creates an FrozenEvent object, filling out auth_events, prev_events, etc. Adds display names to Join membership events. - Persists and notifies local clients and federation. - Args: event_dict (dict): An entire event + + Returns: + Tuple of created event (FrozenEvent), Context """ builder = self.event_builder_factory.new(event_dict) self.validator.validate_new(builder) - if ratelimit: - self.ratelimit(builder.user_id) - # TODO(paul): Why does 'event' not have a 'user' object? - user = UserID.from_string(builder.user_id) - assert self.hs.is_mine(user), "User must be our own: %s" % (user,) - if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: @@ -216,6 +211,25 @@ class MessageHandler(BaseHandler): event, context = yield self._create_new_client_event( builder=builder, ) + defer.returnValue((event, context)) + + @defer.inlineCallbacks + def send_event(self, event, context, ratelimit=True, is_guest=False): + """ + Persists and notifies local clients and federation of an event. + + Args: + event (FrozenEvent) the event to send. + context (Context) the context of the event. + ratelimit (bool): Whether to rate limit this send. + is_guest (bool): Whether the sender is a guest. + """ + user = UserID.from_string(event.sender) + + assert self.hs.is_mine(user), "User must be our own: %s" % (user,) + + if ratelimit: + self.ratelimit(event.sender) if event.is_state(): prev_state = context.current_state.get((event.type, event.state_key)) @@ -229,7 +243,7 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Member: member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.change_membership(event, context, is_guest=is_guest) + yield member_handler.send_membership_event(event, context, is_guest=is_guest) else: yield self.handle_new_client_event( event=event, @@ -241,6 +255,25 @@ class MessageHandler(BaseHandler): with PreserveLoggingContext(): presence.bump_presence_active_time(user) + @defer.inlineCallbacks + def create_and_send_event(self, event_dict, ratelimit=True, + token_id=None, txn_id=None, is_guest=False): + """ + Creates an event, then sends it. + + See self.create_event and self.send_event. + """ + event, context = yield self.create_event( + event_dict, + token_id=token_id, + txn_id=txn_id + ) + yield self.send_event( + event, + context, + ratelimit=ratelimit, + is_guest=is_guest + ) defer.returnValue(event) @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a410e4394c..a1baf9d200 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,7 +22,7 @@ from synapse.types import UserID, RoomAlias, RoomID from synapse.api.constants import ( EventTypes, Membership, JoinRules, RoomCreationPreset, ) -from synapse.api.errors import AuthError, StoreError, SynapseError +from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor @@ -397,7 +397,58 @@ class RoomMemberHandler(BaseHandler): remotedomains.add(member.domain) @defer.inlineCallbacks - def change_membership(self, event, context, is_guest=False): + def update_membership(self, requester, target, room_id, action, txn_id=None): + effective_membership_state = action + if action in ["kick", "unban"]: + effective_membership_state = "leave" + elif action == "forget": + effective_membership_state = "leave" + + msg_handler = self.hs.get_handlers().message_handler + + content = {"membership": unicode(effective_membership_state)} + if requester.is_guest: + content["kind"] = "guest" + + event, context = yield msg_handler.create_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": target.to_string(), + }, + token_id=requester.access_token_id, + txn_id=txn_id, + ) + + old_state = context.current_state.get((EventTypes.Member, event.state_key)) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was is banned" % (action,), + errcode=Codes.BAD_STATE + ) + + yield msg_handler.send_event( + event, + context, + ratelimit=True, + is_guest=requester.is_guest + ) + + if action == "forget": + yield self.forget(requester.user, room_id) + + @defer.inlineCallbacks + def send_membership_event(self, event, context, is_guest=False): """ Change the membership status of a user in a room. Args: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 8b1b2b852d..85b9f253e3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -442,7 +442,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P[^/]*)/" - "(?Pjoin|invite|leave|ban|kick|forget)") + "(?Pjoin|invite|leave|ban|unban|kick|forget)") register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks @@ -451,9 +451,6 @@ class RoomMembershipRestServlet(ClientV1RestServlet): request, allow_guest=True, ) - user = requester.user - - effective_membership_action = membership_action if requester.is_guest and membership_action not in { Membership.JOIN, @@ -463,13 +460,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content = _parse_json(request) - # target user is you unless it is an invite - state_key = user.to_string() - if membership_action == "invite" and self._has_3pid_invite_keys(content): yield self.handlers.room_member_handler.do_3pid_invite( room_id, - user, + requester.user, content["medium"], content["address"], content["id_server"], @@ -478,42 +472,21 @@ class RoomMembershipRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) return - elif membership_action in ["invite", "ban", "kick"]: - if "user_id" in content: - state_key = content["user_id"] - else: - raise SynapseError(400, "Missing user_id key.") - - # make sure it looks like a user ID; it'll throw if it's invalid. - UserID.from_string(state_key) - if membership_action == "kick": - effective_membership_action = "leave" - elif membership_action == "forget": - effective_membership_action = "leave" - - msg_handler = self.handlers.message_handler - - content = {"membership": unicode(effective_membership_action)} - if requester.is_guest: - content["kind"] = "guest" + target = requester.user + if membership_action in ["invite", "ban", "unban", "kick"]: + if "user_id" not in content: + raise SynapseError(400, "Missing user_id key.") + target = UserID.from_string(content["user_id"]) - yield msg_handler.create_and_send_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": user.to_string(), - "state_key": state_key, - }, - token_id=requester.access_token_id, + yield self.handlers.room_member_handler.update_membership( + requester=requester, + target=target, + room_id=room_id, + action=membership_action, txn_id=txn_id, - is_guest=requester.is_guest, ) - if membership_action == "forget": - yield self.handlers.room_member_handler.forget(user, room_id) - defer.returnValue((200, {})) def _has_3pid_invite_keys(self, content): diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py index 97491848a3..e7a12a2ba2 100644 --- a/tests/handlers/test_room.py +++ b/tests/handlers/test_room.py @@ -156,7 +156,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): builder ) - yield room_handler.change_membership(event, context) + yield room_handler.send_membership_event(event, context) self.state_handler.compute_event_context.assert_called_once_with( builder @@ -232,7 +232,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ) # Actual invocation - yield room_handler.change_membership(event, context) + yield room_handler.send_membership_event(event, context) self.federation.handle_new_event.assert_called_once_with( event, destinations=set() @@ -312,7 +312,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): self.distributor.observe("user_left_room", leave_signal_observer) # Actual invocation - yield room_handler.change_membership(event, context) + yield room_handler.send_membership_event(event, context) self.federation.handle_new_event.assert_called_once_with( event, destinations=set(['red']) -- cgit 1.5.1 From 2c176e02ae910ce52197539b31f78ae1b1ef4c3c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 18 Jan 2016 14:24:31 +0000 Subject: Make unit tests work --- synapse/storage/push_rule.py | 2 +- tests/handlers/test_federation.py | 141 ------------- tests/handlers/test_room.py | 418 -------------------------------------- 3 files changed, 1 insertion(+), 560 deletions(-) delete mode 100644 tests/handlers/test_federation.py delete mode 100644 tests/handlers/test_room.py (limited to 'tests') diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 1adf28b893..f210e6c14d 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -14,7 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from twisted.internet import defer import logging diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py deleted file mode 100644 index 11a3d94bb0..0000000000 --- a/tests/handlers/test_federation.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer -from tests import unittest - -from synapse.api.constants import EventTypes -from synapse.events import FrozenEvent -from synapse.handlers.federation import FederationHandler - -from mock import NonCallableMock, ANY, Mock - -from ..utils import setup_test_homeserver - - -class FederationTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - - self.state_handler = NonCallableMock(spec_set=[ - "compute_event_context", - ]) - - self.auth = NonCallableMock(spec_set=[ - "check", - "check_host_in_room", - ]) - - self.hostname = "test" - hs = yield setup_test_homeserver( - self.hostname, - datastore=NonCallableMock(spec_set=[ - "persist_event", - "store_room", - "get_room", - "get_destination_retry_timings", - "set_destination_retry_timings", - "have_events", - "get_users_in_room", - "bulk_get_push_rules", - "get_current_state", - "set_push_actions_for_event_and_users", - "is_guest", - "get_state_for_events", - ]), - resource_for_federation=NonCallableMock(), - http_client=NonCallableMock(spec_set=[]), - notifier=NonCallableMock(spec_set=["on_new_room_event"]), - handlers=NonCallableMock(spec_set=[ - "room_member_handler", - "federation_handler", - ]), - auth=self.auth, - state_handler=self.state_handler, - keyring=Mock(), - ) - - self.datastore = hs.get_datastore() - self.handlers = hs.get_handlers() - self.notifier = hs.get_notifier() - self.hs = hs - - self.handlers.federation_handler = FederationHandler(self.hs) - - self.datastore.get_state_for_events.return_value = {"$a:b": {}} - - @defer.inlineCallbacks - def test_msg(self): - pdu = FrozenEvent({ - "type": EventTypes.Message, - "room_id": "foo", - "content": {"msgtype": u"fooo"}, - "origin_server_ts": 0, - "event_id": "$a:b", - "user_id":"@a:b", - "origin": "b", - "auth_events": [], - "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, - }) - - self.datastore.persist_event.return_value = defer.succeed((1,1)) - self.datastore.get_room.return_value = defer.succeed(True) - self.datastore.get_users_in_room.return_value = ["@a:b"] - self.datastore.bulk_get_push_rules.return_value = {} - self.datastore.get_current_state.return_value = {} - self.auth.check_host_in_room.return_value = defer.succeed(True) - - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - } - self.datastore.get_destination_retry_timings.return_value = ( - defer.succeed(retry_timings_res) - ) - - def have_events(event_ids): - return defer.succeed({}) - self.datastore.have_events.side_effect = have_events - - def annotate(ev, old_state=None, outlier=False): - context = Mock() - context.current_state = {} - context.auth_events = {} - return defer.succeed(context) - self.state_handler.compute_event_context.side_effect = annotate - - yield self.handlers.federation_handler.on_receive_pdu( - "fo", pdu, False - ) - - self.datastore.persist_event.assert_called_once_with( - ANY, - is_new_state=True, - backfilled=False, - current_state=None, - context=ANY, - ) - - self.state_handler.compute_event_context.assert_called_once_with( - ANY, old_state=None, outlier=False - ) - - self.auth.check.assert_called_once_with(ANY, auth_events={}) - - self.notifier.on_new_room_event.assert_called_once_with( - ANY, 1, 1, extra_users=[] - ) diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py deleted file mode 100644 index e7a12a2ba2..0000000000 --- a/tests/handlers/test_room.py +++ /dev/null @@ -1,418 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer -from .. import unittest - -from synapse.api.constants import EventTypes, Membership -from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler -from synapse.handlers.profile import ProfileHandler -from synapse.types import UserID -from ..utils import setup_test_homeserver - -from mock import Mock, NonCallableMock - - -class RoomMemberHandlerTestCase(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - self.hostname = "red" - hs = yield setup_test_homeserver( - self.hostname, - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - datastore=NonCallableMock(spec_set=[ - "persist_event", - "get_room_member", - "get_room", - "store_room", - "get_latest_events_in_room", - "add_event_hashes", - "get_users_in_room", - "bulk_get_push_rules", - "get_current_state", - "set_push_actions_for_event_and_users", - "get_state_for_events", - "is_guest", - ]), - resource_for_federation=NonCallableMock(), - http_client=NonCallableMock(spec_set=[]), - notifier=NonCallableMock(spec_set=["on_new_room_event"]), - handlers=NonCallableMock(spec_set=[ - "room_member_handler", - "profile_handler", - "federation_handler", - ]), - auth=NonCallableMock(spec_set=[ - "check", - "add_auth_events", - "check_host_in_room", - ]), - state_handler=NonCallableMock(spec_set=[ - "compute_event_context", - "get_current_state", - ]), - ) - - self.federation = NonCallableMock(spec_set=[ - "handle_new_event", - "send_invite", - "get_state_for_room", - ]) - - self.datastore = hs.get_datastore() - self.handlers = hs.get_handlers() - self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() - self.distributor = hs.get_distributor() - self.auth = hs.get_auth() - self.hs = hs - - self.handlers.federation_handler = self.federation - - self.distributor.declare("collect_presencelike_data") - - self.handlers.room_member_handler = RoomMemberHandler(self.hs) - self.handlers.profile_handler = ProfileHandler(self.hs) - self.room_member_handler = self.handlers.room_member_handler - - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - self.datastore.persist_event.return_value = (1,1) - self.datastore.add_event_hashes.return_value = [] - self.datastore.get_users_in_room.return_value = ["@bob:red"] - self.datastore.bulk_get_push_rules.return_value = {} - - @defer.inlineCallbacks - def test_invite(self): - room_id = "!foo:red" - user_id = "@bob:red" - target_user_id = "@red:blue" - content = {"membership": Membership.INVITE} - - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": target_user_id, - "room_id": room_id, - "content": content, - }) - - self.datastore.get_latest_events_in_room.return_value = ( - defer.succeed([]) - ) - self.datastore.get_current_state.return_value = {} - self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} - - def annotate(_): - ctx = Mock() - ctx.current_state = { - (EventTypes.Member, "@alice:green"): self._create_member( - user_id="@alice:green", - room_id=room_id, - ), - (EventTypes.Member, "@bob:red"): self._create_member( - user_id="@bob:red", - room_id=room_id, - ), - } - ctx.prev_state_events = [] - - return defer.succeed(ctx) - - self.state_handler.compute_event_context.side_effect = annotate - - def add_auth(_, ctx): - ctx.auth_events = ctx.current_state[ - (EventTypes.Member, "@bob:red") - ] - - return defer.succeed(True) - self.auth.add_auth_events.side_effect = add_auth - - def send_invite(domain, event): - return defer.succeed(event) - - self.federation.send_invite.side_effect = send_invite - - room_handler = self.room_member_handler - event, context = yield room_handler._create_new_client_event( - builder - ) - - yield room_handler.send_membership_event(event, context) - - self.state_handler.compute_event_context.assert_called_once_with( - builder - ) - - self.auth.add_auth_events.assert_called_once_with( - builder, context - ) - - self.federation.send_invite.assert_called_once_with( - "blue", event, - ) - - self.datastore.persist_event.assert_called_once_with( - event, context=context, - ) - self.notifier.on_new_room_event.assert_called_once_with( - event, 1, 1, extra_users=[UserID.from_string(target_user_id)] - ) - self.assertFalse(self.datastore.get_room.called) - self.assertFalse(self.datastore.store_room.called) - self.assertFalse(self.federation.get_state_for_room.called) - - @defer.inlineCallbacks - def test_simple_join(self): - room_id = "!foo:red" - user_id = "@bob:red" - user = UserID.from_string(user_id) - - join_signal_observer = Mock() - self.distributor.observe("user_joined_room", join_signal_observer) - - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": user_id, - "room_id": room_id, - "content": {"membership": Membership.JOIN}, - }) - - self.datastore.get_latest_events_in_room.return_value = ( - defer.succeed([]) - ) - self.datastore.get_current_state.return_value = {} - self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} - - def annotate(_): - ctx = Mock() - ctx.current_state = { - (EventTypes.Member, "@bob:red"): self._create_member( - user_id="@bob:red", - room_id=room_id, - membership=Membership.INVITE - ), - } - ctx.prev_state_events = [] - - return defer.succeed(ctx) - - self.state_handler.compute_event_context.side_effect = annotate - - def add_auth(_, ctx): - ctx.auth_events = ctx.current_state[ - (EventTypes.Member, "@bob:red") - ] - - return defer.succeed(True) - self.auth.add_auth_events.side_effect = add_auth - - room_handler = self.room_member_handler - event, context = yield room_handler._create_new_client_event( - builder - ) - - # Actual invocation - yield room_handler.send_membership_event(event, context) - - self.federation.handle_new_event.assert_called_once_with( - event, destinations=set() - ) - - self.datastore.persist_event.assert_called_once_with( - event, context=context - ) - self.notifier.on_new_room_event.assert_called_once_with( - event, 1, 1, extra_users=[user] - ) - - join_signal_observer.assert_called_with( - user=user, room_id=room_id - ) - - def _create_member(self, user_id, room_id, membership=Membership.JOIN): - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": user_id, - "room_id": room_id, - "content": {"membership": membership}, - }) - - return builder.build() - - @defer.inlineCallbacks - def test_simple_leave(self): - room_id = "!foo:red" - user_id = "@bob:red" - user = UserID.from_string(user_id) - - builder = self.hs.get_event_builder_factory().new({ - "type": EventTypes.Member, - "sender": user_id, - "state_key": user_id, - "room_id": room_id, - "content": {"membership": Membership.LEAVE}, - }) - - self.datastore.get_latest_events_in_room.return_value = ( - defer.succeed([]) - ) - self.datastore.get_current_state.return_value = {} - self.datastore.get_state_for_events = lambda event_ids,types: {x: {} for x in event_ids} - - def annotate(_): - ctx = Mock() - ctx.current_state = { - (EventTypes.Member, "@bob:red"): self._create_member( - user_id="@bob:red", - room_id=room_id, - membership=Membership.JOIN - ), - } - ctx.prev_state_events = [] - - return defer.succeed(ctx) - - self.state_handler.compute_event_context.side_effect = annotate - - def add_auth(_, ctx): - ctx.auth_events = ctx.current_state[ - (EventTypes.Member, "@bob:red") - ] - - return defer.succeed(True) - self.auth.add_auth_events.side_effect = add_auth - - room_handler = self.room_member_handler - event, context = yield room_handler._create_new_client_event( - builder - ) - - leave_signal_observer = Mock() - self.distributor.observe("user_left_room", leave_signal_observer) - - # Actual invocation - yield room_handler.send_membership_event(event, context) - - self.federation.handle_new_event.assert_called_once_with( - event, destinations=set(['red']) - ) - - self.datastore.persist_event.assert_called_once_with( - event, context=context - ) - self.notifier.on_new_room_event.assert_called_once_with( - event, 1, 1, extra_users=[user] - ) - - leave_signal_observer.assert_called_with( - user=user, room_id=room_id - ) - - -class RoomCreationTest(unittest.TestCase): - - @defer.inlineCallbacks - def setUp(self): - self.hostname = "red" - - hs = yield setup_test_homeserver( - self.hostname, - datastore=NonCallableMock(spec_set=[ - "store_room", - "snapshot_room", - "persist_event", - "get_joined_hosts_for_room", - ]), - http_client=NonCallableMock(spec_set=[]), - notifier=NonCallableMock(spec_set=["on_new_room_event"]), - handlers=NonCallableMock(spec_set=[ - "room_creation_handler", - "message_handler", - ]), - auth=NonCallableMock(spec_set=["check", "add_auth_events"]), - ratelimiter=NonCallableMock(spec_set=[ - "send_message", - ]), - ) - - self.federation = NonCallableMock(spec_set=[ - "handle_new_event", - ]) - - self.handlers = hs.get_handlers() - - self.handlers.room_creation_handler = RoomCreationHandler(hs) - self.room_creation_handler = self.handlers.room_creation_handler - - self.message_handler = self.handlers.message_handler - - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) - - @defer.inlineCallbacks - def test_room_creation(self): - user_id = "@foo:red" - room_id = "!bobs_room:red" - config = {"visibility": "private"} - - yield self.room_creation_handler.create_room( - user_id=user_id, - room_id=room_id, - config=config, - ) - - self.assertTrue(self.message_handler.create_and_send_event.called) - - event_dicts = [ - e[0][0] - for e in self.message_handler.create_and_send_event.call_args_list - ] - - self.assertTrue(len(event_dicts) > 3) - - self.assertDictContainsSubset( - { - "type": EventTypes.Create, - "sender": user_id, - "room_id": room_id, - }, - event_dicts[0] - ) - - self.assertEqual(user_id, event_dicts[0]["content"]["creator"]) - - self.assertDictContainsSubset( - { - "type": EventTypes.Member, - "sender": user_id, - "room_id": room_id, - "state_key": user_id, - }, - event_dicts[1] - ) - - self.assertEqual( - Membership.JOIN, - event_dicts[1]["content"]["membership"] - ) -- cgit 1.5.1 From 191070123da7f472bca99c0a89d27fbdca51f972 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 20 Jan 2016 11:34:09 +0000 Subject: Cache dns lookups, and use the cache if we fail to lookup servers later --- synapse/http/endpoint.py | 101 ++++++++++++++++++++++++++++------------- tests/test_dns.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+), 30 deletions(-) create mode 100644 tests/test_dns.py (limited to 'tests') diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4341ded96a..a9e024a415 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -17,7 +17,7 @@ from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.internet import defer from twisted.internet.error import ConnectError from twisted.names import client, dns -from twisted.names.error import DNSNameError +from twisted.names.error import DNSNameError, DomainError import collections import logging @@ -27,6 +27,14 @@ import random logger = logging.getLogger(__name__) +SERVER_CACHE = {} + + +_Server = collections.namedtuple( + "_Server", "priority weight host port" +) + + def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. @@ -73,10 +81,6 @@ class SRVClientEndpoint(object): Implements twisted.internet.interfaces.IStreamClientEndpoint. """ - _Server = collections.namedtuple( - "_Server", "priority weight host port" - ) - def __init__(self, reactor, service, domain, protocol="tcp", default_port=None, endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): @@ -101,32 +105,8 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks def fetch_servers(self): - try: - answers, auth, add = yield client.lookupService(self.service_name) - except DNSNameError: - answers = [] - - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name('.')): - raise ConnectError("Service %s unavailable", self.service_name) - - self.servers = [] self.used_servers = [] - - for answer in answers: - if answer.type != dns.SRV or not answer.payload: - continue - payload = answer.payload - self.servers.append(self._Server( - host=str(payload.target), - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) - - self.servers.sort() + self.servers = yield resolve_service(self.service_name) def pick_server(self): if not self.servers: @@ -170,3 +150,64 @@ class SRVClientEndpoint(object): ) connection = yield endpoint.connect(protocolFactory) defer.returnValue(connection) + + +@defer.inlineCallbacks +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): + servers = [] + + try: + try: + answers, _, _ = yield dns_client.lookupService(service_name) + except DNSNameError: + defer.returnValue([]) + + if (len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name('.')): + raise ConnectError("Service %s unavailable", service_name) + + for answer in answers: + if answer.type != dns.SRV or not answer.payload: + continue + + payload = answer.payload + + host = str(payload.target) + + try: + answers, _, _ = yield dns_client.lookupAddress(host) + except DNSNameError: + continue + + ips = [ + answer.payload.dottedQuad() + for answer in answers + if answer.type == dns.A and answer.payload + ] + + for ip in ips: + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight) + )) + + servers.sort() + cache[service_name] = list(servers) + except DomainError as e: + # We failed to resolve the name (other than a NameError) + # Try something in the cache, else rereaise + cache_entry = cache.get(service_name, None) + if cache_entry: + logger.warn( + "Failed to resolve %r, falling back to cache. %r", + service_name, e + ) + servers = list(cache_entry) + else: + raise e + + defer.returnValue(servers) diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 0000000000..637b1606f8 --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import unittest +from twisted.internet import defer +from twisted.names import dns, error + +from mock import Mock + +from synapse.http.endpoint import resolve_service + + +class DnsTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def test_resolve(self): + dns_client_mock = Mock() + + service_name = "test_service.examle.com" + host_name = "example.com" + ip_address = "127.0.0.1" + + answer_srv = dns.RRHeader( + type=dns.SRV, + payload=dns.Record_SRV( + target=host_name, + ) + ) + + answer_a = dns.RRHeader( + type=dns.A, + payload=dns.Record_A( + address=ip_address, + ) + ) + + dns_client_mock.lookupService.return_value = ([answer_srv], None, None) + dns_client_mock.lookupAddress.return_value = ([answer_a], None, None) + + cache = {} + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + dns_client_mock.lookupAddress.assert_called_once_with(host_name) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + self.assertEquals(servers[0].host, ip_address) + + @defer.inlineCallbacks + def test_from_cache(self): + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = "test_service.examle.com" + + cache = { + service_name: [object()] + } + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks + def test_empty_cache(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = "test_service.examle.com" + + cache = {} + + with self.assertRaises(error.DNSServerError): + yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + @defer.inlineCallbacks + def test_name_error(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) + + service_name = "test_service.examle.com" + + cache = {} + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + self.assertEquals(len(servers), 0) + self.assertEquals(len(cache), 0) -- cgit 1.5.1 From f1f81221205cf2ec101f96234050569d6419fd6b Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 21 Jan 2016 19:16:25 +0000 Subject: Change LRUCache to be tree-based so we can delete subtrees. --- synapse/push/push_rule_evaluator.py | 6 ++-- synapse/util/caches/descriptors.py | 11 ++++++- synapse/util/caches/dictionary_cache.py | 10 +++---- synapse/util/caches/lrucache.py | 43 ++++++++++++++++++++++----- synapse/util/caches/treecache.py | 52 +++++++++++++++++++++++++++++++++ tests/storage/test__base.py | 26 ++++++++--------- tests/util/test_lrucache.py | 44 ++++++++++++++-------------- 7 files changed, 140 insertions(+), 52 deletions(-) create mode 100644 synapse/util/caches/treecache.py (limited to 'tests') diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index dca018af95..27b0de4f66 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -309,14 +309,14 @@ def _flatten_dict(d, prefix=[], result={}): return result -regex_cache = LruCache(5000) +regex_cache = LruCache(5000, 1) def _compile_regex(regex_str): - r = regex_cache.get(regex_str, None) + r = regex_cache.get((regex_str,), None) if r: return r r = re.compile(regex_str, flags=re.IGNORECASE) - regex_cache[regex_str] = r + regex_cache[(regex_str,)] = r return r diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 0033051849..af7bf15500 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -38,7 +38,7 @@ class Cache(object): def __init__(self, name, max_entries=1000, keylen=1, lru=True): if lru: - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, keylen=keylen) self.max_entries = None else: self.cache = OrderedDict() @@ -99,6 +99,15 @@ class Cache(object): self.sequence += 1 self.cache.pop(key, None) + def invalidate_many(self, key): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError( + "The cache key must be a tuple not %r" % (type(key),) + ) + self.sequence += 1 + self.cache.del_multi(key) + def invalidate_all(self): self.check_thread() self.sequence += 1 diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index f92d80542b..b7964467eb 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -32,7 +32,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries) + self.cache = LruCache(max_size=max_entries, keylen=1) self.name = name self.sequence = 0 @@ -56,7 +56,7 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): - entry = self.cache.get(key, self.sentinel) + entry = self.cache.get((key,), self.sentinel) if entry is not self.sentinel: cache_counter.inc_hits(self.name) @@ -78,7 +78,7 @@ class DictionaryCache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 - self.cache.pop(key, None) + self.cache.pop((key,), None) def invalidate_all(self): self.check_thread() @@ -96,8 +96,8 @@ class DictionaryCache(object): self._update_or_insert(key, value) def _update_or_insert(self, key, value): - entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + entry = self.cache.setdefault((key,), DictionaryEntry(False, {})) entry.value.update(value) def _insert(self, key, value): - self.cache[key] = DictionaryEntry(True, value) + self.cache[(key,)] = DictionaryEntry(True, value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 0122b0bb3f..0feceb298a 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -17,11 +17,23 @@ from functools import wraps import threading +from synapse.util.caches.treecache import TreeCache + + +def enumerate_leaves(node, depth): + if depth == 0: + yield node + else: + for n in node.values(): + for m in enumerate_leaves(n, depth - 1): + yield m + class LruCache(object): """Least-recently-used cache.""" - def __init__(self, max_size): - cache = {} + def __init__(self, max_size, keylen): + cache = TreeCache() + self.size = 0 list_root = [] list_root[:] = [list_root, list_root, None, None] @@ -44,6 +56,7 @@ class LruCache(object): prev_node[NEXT] = node next_node[PREV] = node cache[key] = node + self.size += 1 def move_node_to_front(node): prev_node = node[PREV] @@ -62,7 +75,7 @@ class LruCache(object): next_node = node[NEXT] prev_node[NEXT] = next_node next_node[PREV] = prev_node - cache.pop(node[KEY], None) + self.size -= 1 @synchronized def cache_get(key, default=None): @@ -81,8 +94,10 @@ class LruCache(object): node[VALUE] = value else: add_node(key, value) - if len(cache) > max_size: - delete_node(list_root[PREV]) + if self.size > max_size: + todelete = list_root[PREV] + delete_node(todelete) + cache.pop(todelete[KEY], None) @synchronized def cache_set_default(key, value): @@ -91,8 +106,10 @@ class LruCache(object): return node[VALUE] else: add_node(key, value) - if len(cache) > max_size: - delete_node(list_root[PREV]) + if self.size > max_size: + todelete = list_root[PREV] + delete_node(todelete) + cache.pop(todelete[KEY], None) return value @synchronized @@ -100,10 +117,19 @@ class LruCache(object): node = cache.get(key, None) if node: delete_node(node) + cache.pop(node[KEY], None) return node[VALUE] else: return default + @synchronized + def cache_del_multi(key): + popped = cache.pop(key) + if popped is None: + return + for leaf in enumerate_leaves(popped, keylen - len(key)): + delete_node(leaf) + @synchronized def cache_clear(): list_root[NEXT] = list_root @@ -112,7 +138,7 @@ class LruCache(object): @synchronized def cache_len(): - return len(cache) + return self.size @synchronized def cache_contains(key): @@ -123,6 +149,7 @@ class LruCache(object): self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + self.del_multi = cache_del_multi self.len = cache_len self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py new file mode 100644 index 0000000000..1e5f87e6ad --- /dev/null +++ b/synapse/util/caches/treecache.py @@ -0,0 +1,52 @@ +SENTINEL = object() + + +class TreeCache(object): + def __init__(self): + self.root = {} + + def __setitem__(self, key, value): + return self.set(key, value) + + def set(self, key, value): + node = self.root + for k in key[:-1]: + node = node.setdefault(k, {}) + node[key[-1]] = value + + def get(self, key, default=None): + node = self.root + for k in key[:-1]: + node = node.get(k, None) + if node is None: + return default + return node.get(key[-1], default) + + def clear(self): + self.root = {} + + def pop(self, key, default=None): + nodes = [] + + node = self.root + for k in key[:-1]: + node = node.get(k, None) + nodes.append(node) # don't add the root node + if node is None: + return default + popped = node.pop(key[-1], SENTINEL) + if popped is SENTINEL: + return default + + node_and_keys = zip(nodes, key) + node_and_keys.reverse() + node_and_keys.append((self.root, None)) + + for i in range(len(node_and_keys) - 1): + n,k = node_and_keys[i] + + if n: + break + node_and_keys[i+1][0].pop(k) + + return popped \ No newline at end of file diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 219288621d..c4e4c9b4bf 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -56,42 +56,42 @@ class CacheTestCase(unittest.TestCase): def test_eviction(self): cache = Cache("test", max_entries=2) - cache.prefill(1, "one") - cache.prefill(2, "two") - cache.prefill(3, "three") # 1 will be evicted + cache.prefill((1,), "one") + cache.prefill((2,), "two") + cache.prefill((3,), "three") # 1 will be evicted failed = False try: - cache.get(1) + cache.get((1,)) except KeyError: failed = True self.assertTrue(failed) - cache.get(2) - cache.get(3) + cache.get((2,)) + cache.get((3,)) def test_eviction_lru(self): cache = Cache("test", max_entries=2, lru=True) - cache.prefill(1, "one") - cache.prefill(2, "two") + cache.prefill((1,), "one") + cache.prefill((2,), "two") # Now access 1 again, thus causing 2 to be least-recently used - cache.get(1) + cache.get((1,)) - cache.prefill(3, "three") + cache.prefill((3,), "three") failed = False try: - cache.get(2) + cache.get((2,)) except KeyError: failed = True self.assertTrue(failed) - cache.get(1) - cache.get(3) + cache.get((1,)) + cache.get((3,)) class CacheDecoratorTestCase(unittest.TestCase): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index fbbc5eed15..80c19b944a 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -21,34 +21,34 @@ from synapse.util.caches.lrucache import LruCache class LruCacheTestCase(unittest.TestCase): def test_get_set(self): - cache = LruCache(1) - cache["key"] = "value" - self.assertEquals(cache.get("key"), "value") - self.assertEquals(cache["key"], "value") + cache = LruCache(1, 1) + cache[("key",)] = "value" + self.assertEquals(cache.get(("key",)), "value") + self.assertEquals(cache[("key",)], "value") def test_eviction(self): - cache = LruCache(2) - cache[1] = 1 - cache[2] = 2 + cache = LruCache(2, 1) + cache[(1,)] = 1 + cache[(2,)] = 2 - self.assertEquals(cache.get(1), 1) - self.assertEquals(cache.get(2), 2) + self.assertEquals(cache.get((1,)), 1) + self.assertEquals(cache.get((2,)), 2) - cache[3] = 3 + cache[(3,)] = 3 - self.assertEquals(cache.get(1), None) - self.assertEquals(cache.get(2), 2) - self.assertEquals(cache.get(3), 3) + self.assertEquals(cache.get((1,)), None) + self.assertEquals(cache.get((2,)), 2) + self.assertEquals(cache.get((3,)), 3) def test_setdefault(self): - cache = LruCache(1) - self.assertEquals(cache.setdefault("key", 1), 1) - self.assertEquals(cache.get("key"), 1) - self.assertEquals(cache.setdefault("key", 2), 1) - self.assertEquals(cache.get("key"), 1) + cache = LruCache(1, 1) + self.assertEquals(cache.setdefault(("key",), 1), 1) + self.assertEquals(cache.get(("key",)), 1) + self.assertEquals(cache.setdefault(("key",), 2), 1) + self.assertEquals(cache.get(("key",)), 1) def test_pop(self): - cache = LruCache(1) - cache["key"] = 1 - self.assertEquals(cache.pop("key"), 1) - self.assertEquals(cache.pop("key"), None) + cache = LruCache(1, 1) + cache[("key",)] = 1 + self.assertEquals(cache.pop(("key",)), 1) + self.assertEquals(cache.pop(("key",)), None) -- cgit 1.5.1 From 4efcaa43c8c69c7fdbaec74d7af2b71dbc6faea6 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jan 2016 10:37:37 +0000 Subject: Add tests for treecache directly and test del_multi at the LruCache level too. --- tests/util/test_treecache.py | 66 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tests/util/test_treecache.py (limited to 'tests') diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py new file mode 100644 index 0000000000..9946ceb3f1 --- /dev/null +++ b/tests/util/test_treecache.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .. import unittest + +from synapse.util.caches.treecache import TreeCache + +class TreeCacheTestCase(unittest.TestCase): + def test_get_set_onelevel(self): + cache = TreeCache() + cache[("a",)] = "A" + cache[("b",)] = "B" + self.assertEquals(cache.get(("a",)), "A") + self.assertEquals(cache.get(("b",)), "B") + + def test_pop_onelevel(self): + cache = TreeCache() + cache[("a",)] = "A" + cache[("b",)] = "B" + self.assertEquals(cache.pop(("a",)), "A") + self.assertEquals(cache.pop(("a",)), None) + self.assertEquals(cache.get(("b",)), "B") + + def test_get_set_twolevel(self): + cache = TreeCache() + cache[("a", "a")] = "AA" + cache[("a", "b")] = "AB" + cache[("b", "a")] = "BA" + self.assertEquals(cache.get(("a", "a")), "AA") + self.assertEquals(cache.get(("a", "b")), "AB") + self.assertEquals(cache.get(("b", "a")), "BA") + + def test_pop_twolevel(self): + cache = TreeCache() + cache[("a", "a")] = "AA" + cache[("a", "b")] = "AB" + cache[("b", "a")] = "BA" + self.assertEquals(cache.pop(("a", "a")), "AA") + self.assertEquals(cache.get(("a", "a")), None) + self.assertEquals(cache.get(("a", "b")), "AB") + self.assertEquals(cache.pop(("b", "a")), "BA") + self.assertEquals(cache.pop(("b", "a")), None) + + def test_pop_mixedlevel(self): + cache = TreeCache() + cache[("a", "a")] = "AA" + cache[("a", "b")] = "AB" + cache[("b", "a")] = "BA" + self.assertEquals(cache.get(("a", "a")), "AA") + cache.pop(("a",)) + self.assertEquals(cache.get(("a", "a")), None) + self.assertEquals(cache.get(("a", "b")), None) + self.assertEquals(cache.get(("b", "a")), "BA") -- cgit 1.5.1 From 8f9c74e9f12050b6680355fc93758b28672e9358 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 Jan 2016 10:48:27 +0000 Subject: Fix tests --- tests/api/test_filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 14cddee679..16ee6bbe6a 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -504,4 +504,4 @@ class FilteringTestCase(unittest.TestCase): filter_id=filter_id, ) - self.assertEquals(filter.filter_json, user_filter_json) + self.assertEquals(filter.get_filter_json(), user_filter_json) -- cgit 1.5.1 From 31a051b6771bf720c78246bd6c8875d219ddbc88 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jan 2016 11:22:00 +0000 Subject: Test treecache directly --- tests/util/test_lrucache.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'tests') diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 80c19b944a..fca2e98983 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -52,3 +52,22 @@ class LruCacheTestCase(unittest.TestCase): cache[("key",)] = 1 self.assertEquals(cache.pop(("key",)), 1) self.assertEquals(cache.pop(("key",)), None) + + def test_del_multi(self): + cache = LruCache(4, 2) + cache[("animal", "cat")] = "mew" + cache[("animal", "dog")] = "woof" + cache[("vehicles", "car")] = "vroom" + cache[("vehicles", "train")] = "chuff" + + self.assertEquals(len(cache), 4) + + self.assertEquals(cache.get(("animal", "cat")), "mew") + self.assertEquals(cache.get(("vehicles", "car")), "vroom") + cache.del_multi(("animal",)) + self.assertEquals(len(cache), 2) + self.assertEquals(cache.get(("animal", "cat")), None) + self.assertEquals(cache.get(("animal", "dog")), None) + self.assertEquals(cache.get(("vehicles", "car")), "vroom") + self.assertEquals(cache.get(("vehicles", "train")), "chuff") + # Man from del_multi say "Yes". -- cgit 1.5.1 From 10f76dc5da47c49a4191d8113b3c0615224eb9fd Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jan 2016 12:10:33 +0000 Subject: Make LRU cache not default to treecache & add options to use it --- synapse/storage/event_push_actions.py | 2 +- synapse/util/caches/descriptors.py | 20 ++++++++++++++------ synapse/util/caches/lrucache.py | 9 +++++---- tests/util/test_lrucache.py | 3 ++- 4 files changed, 22 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 6a212c630b..a05c4f84cf 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -53,7 +53,7 @@ class EventPushActionsStore(SQLBaseStore): f, ) - @cachedInlineCallbacks(num_args=3, lru=True) + @cachedInlineCallbacks(num_args=3, lru=True, tree=True) def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index f4a2b4e590..88e56e3302 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -17,6 +17,7 @@ import logging from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache from . import caches_by_name, DEBUG_CACHES, cache_counter @@ -36,9 +37,12 @@ _CacheSentinel = object() class Cache(object): - def __init__(self, name, max_entries=1000, keylen=1, lru=True): + def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): if lru: - self.cache = LruCache(max_size=max_entries, keylen=keylen) + cache_type = TreeCache if tree else dict + self.cache = LruCache( + max_size=max_entries, keylen=keylen, cache_type=cache_type + ) self.max_entries = None else: self.cache = OrderedDict() @@ -131,7 +135,7 @@ class CacheDescriptor(object): which can be used to insert values into the cache specifically, without calling the calculation function. """ - def __init__(self, orig, max_entries=1000, num_args=1, lru=True, + def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False, inlineCallbacks=False): self.orig = orig @@ -143,6 +147,7 @@ class CacheDescriptor(object): self.max_entries = max_entries self.num_args = num_args self.lru = lru + self.tree = tree self.arg_names = inspect.getargspec(orig).args[1:num_args+1] @@ -158,6 +163,7 @@ class CacheDescriptor(object): max_entries=self.max_entries, keylen=self.num_args, lru=self.lru, + tree=self.tree, ) def __get__(self, obj, objtype=None): @@ -331,21 +337,23 @@ class CacheListDescriptor(object): return wrapped -def cached(max_entries=1000, num_args=1, lru=True): +def cached(max_entries=1000, num_args=1, lru=True, tree=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, - lru=lru + lru=lru, + tree=tree, ) -def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): +def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, num_args=num_args, lru=lru, + tree=tree, inlineCallbacks=True, ) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 0feceb298a..23e86ec110 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -17,8 +17,6 @@ from functools import wraps import threading -from synapse.util.caches.treecache import TreeCache - def enumerate_leaves(node, depth): if depth == 0: @@ -31,8 +29,8 @@ def enumerate_leaves(node, depth): class LruCache(object): """Least-recently-used cache.""" - def __init__(self, max_size, keylen): - cache = TreeCache() + def __init__(self, max_size, keylen, cache_type=dict): + cache = cache_type() self.size = 0 list_root = [] list_root[:] = [list_root, list_root, None, None] @@ -124,6 +122,9 @@ class LruCache(object): @synchronized def cache_del_multi(key): + """ + This will only work if constructed with cache_type=TreeCache + """ popped = cache.pop(key) if popped is None: return diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index fca2e98983..bcad1d4258 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -17,6 +17,7 @@ from .. import unittest from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache class LruCacheTestCase(unittest.TestCase): @@ -54,7 +55,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(cache.pop(("key",)), None) def test_del_multi(self): - cache = LruCache(4, 2) + cache = LruCache(4, 2, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" -- cgit 1.5.1 From d552861346d6f2f3d50fa0aff3e239d17cf9b7c0 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 22 Jan 2016 12:18:14 +0000 Subject: Revert all the bits changing keys of eeverything that used LRUCaches to tuples --- synapse/push/push_rule_evaluator.py | 6 ++--- synapse/util/caches/dictionary_cache.py | 10 ++++---- synapse/util/caches/lrucache.py | 2 +- tests/storage/test__base.py | 26 +++++++++---------- tests/util/test_lrucache.py | 44 ++++++++++++++++----------------- 5 files changed, 44 insertions(+), 44 deletions(-) (limited to 'tests') diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 27b0de4f66..dca018af95 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -309,14 +309,14 @@ def _flatten_dict(d, prefix=[], result={}): return result -regex_cache = LruCache(5000, 1) +regex_cache = LruCache(5000) def _compile_regex(regex_str): - r = regex_cache.get((regex_str,), None) + r = regex_cache.get(regex_str, None) if r: return r r = re.compile(regex_str, flags=re.IGNORECASE) - regex_cache[(regex_str,)] = r + regex_cache[regex_str] = r return r diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index b7964467eb..f92d80542b 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -32,7 +32,7 @@ class DictionaryCache(object): """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries, keylen=1) + self.cache = LruCache(max_size=max_entries) self.name = name self.sequence = 0 @@ -56,7 +56,7 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): - entry = self.cache.get((key,), self.sentinel) + entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: cache_counter.inc_hits(self.name) @@ -78,7 +78,7 @@ class DictionaryCache(object): # Increment the sequence number so that any SELECT statements that # raced with the INSERT don't update the cache (SYN-369) self.sequence += 1 - self.cache.pop((key,), None) + self.cache.pop(key, None) def invalidate_all(self): self.check_thread() @@ -96,8 +96,8 @@ class DictionaryCache(object): self._update_or_insert(key, value) def _update_or_insert(self, key, value): - entry = self.cache.setdefault((key,), DictionaryEntry(False, {})) + entry = self.cache.setdefault(key, DictionaryEntry(False, {})) entry.value.update(value) def _insert(self, key, value): - self.cache[(key,)] = DictionaryEntry(True, value) + self.cache[key] = DictionaryEntry(True, value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 23e86ec110..5f9405c95f 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -29,7 +29,7 @@ def enumerate_leaves(node, depth): class LruCache(object): """Least-recently-used cache.""" - def __init__(self, max_size, keylen, cache_type=dict): + def __init__(self, max_size, keylen=1, cache_type=dict): cache = cache_type() self.size = 0 list_root = [] diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index c4e4c9b4bf..219288621d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -56,42 +56,42 @@ class CacheTestCase(unittest.TestCase): def test_eviction(self): cache = Cache("test", max_entries=2) - cache.prefill((1,), "one") - cache.prefill((2,), "two") - cache.prefill((3,), "three") # 1 will be evicted + cache.prefill(1, "one") + cache.prefill(2, "two") + cache.prefill(3, "three") # 1 will be evicted failed = False try: - cache.get((1,)) + cache.get(1) except KeyError: failed = True self.assertTrue(failed) - cache.get((2,)) - cache.get((3,)) + cache.get(2) + cache.get(3) def test_eviction_lru(self): cache = Cache("test", max_entries=2, lru=True) - cache.prefill((1,), "one") - cache.prefill((2,), "two") + cache.prefill(1, "one") + cache.prefill(2, "two") # Now access 1 again, thus causing 2 to be least-recently used - cache.get((1,)) + cache.get(1) - cache.prefill((3,), "three") + cache.prefill(3, "three") failed = False try: - cache.get((2,)) + cache.get(2) except KeyError: failed = True self.assertTrue(failed) - cache.get((1,)) - cache.get((3,)) + cache.get(1) + cache.get(3) class CacheDecoratorTestCase(unittest.TestCase): diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index bcad1d4258..2cd3d26454 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,37 +22,37 @@ from synapse.util.caches.treecache import TreeCache class LruCacheTestCase(unittest.TestCase): def test_get_set(self): - cache = LruCache(1, 1) - cache[("key",)] = "value" - self.assertEquals(cache.get(("key",)), "value") - self.assertEquals(cache[("key",)], "value") + cache = LruCache(1) + cache["key"] = "value" + self.assertEquals(cache.get("key"), "value") + self.assertEquals(cache["key"], "value") def test_eviction(self): - cache = LruCache(2, 1) - cache[(1,)] = 1 - cache[(2,)] = 2 + cache = LruCache(2) + cache[1] = 1 + cache[2] = 2 - self.assertEquals(cache.get((1,)), 1) - self.assertEquals(cache.get((2,)), 2) + self.assertEquals(cache.get(1), 1) + self.assertEquals(cache.get(2), 2) - cache[(3,)] = 3 + cache[3] = 3 - self.assertEquals(cache.get((1,)), None) - self.assertEquals(cache.get((2,)), 2) - self.assertEquals(cache.get((3,)), 3) + self.assertEquals(cache.get(1), None) + self.assertEquals(cache.get(2), 2) + self.assertEquals(cache.get(3), 3) def test_setdefault(self): - cache = LruCache(1, 1) - self.assertEquals(cache.setdefault(("key",), 1), 1) - self.assertEquals(cache.get(("key",)), 1) - self.assertEquals(cache.setdefault(("key",), 2), 1) - self.assertEquals(cache.get(("key",)), 1) + cache = LruCache(1) + self.assertEquals(cache.setdefault("key", 1), 1) + self.assertEquals(cache.get("key"), 1) + self.assertEquals(cache.setdefault("key", 2), 1) + self.assertEquals(cache.get("key"), 1) def test_pop(self): - cache = LruCache(1, 1) - cache[("key",)] = 1 - self.assertEquals(cache.pop(("key",)), 1) - self.assertEquals(cache.pop(("key",)), None) + cache = LruCache(1) + cache["key"] = 1 + self.assertEquals(cache.pop("key"), 1) + self.assertEquals(cache.pop("key"), None) def test_del_multi(self): cache = LruCache(4, 2, cache_type=TreeCache) -- cgit 1.5.1 From 8c6012a4af4973b0a53af65a31cbdb92a3dec5a2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 25 Jan 2016 13:12:35 +0000 Subject: Fix tests --- synapse/api/filtering.py | 2 +- tests/api/test_filtering.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 116060ee7f..6c13ada5df 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -193,7 +193,7 @@ class Filter(object): sender = event.get("sender", None) if not sender: # Presence events have their 'sender' in content.user_id - sender = event.get("conntent", {}).get("user_id", None) + sender = event.get("content", {}).get("user_id", None) return self.check_fields( event.get("room_id", None), diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 16ee6bbe6a..1a4e439d30 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -13,26 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from tests import unittest from twisted.internet import defer -from mock import Mock, NonCallableMock +from mock import Mock from tests.utils import ( MockHttpResource, DeferredMockCallable, setup_test_homeserver ) from synapse.types import UserID -from synapse.api.filtering import FilterCollection, Filter +from synapse.api.filtering import Filter +from synapse.events import FrozenEvent user_localpart = "test_user" # MockEvent = namedtuple("MockEvent", "sender type room_id") def MockEvent(**kwargs): - ev = NonCallableMock(spec_set=kwargs.keys()) - ev.configure_mock(**kwargs) - return ev + return FrozenEvent(kwargs) class FilteringTestCase(unittest.TestCase): -- cgit 1.5.1 From 9959d9ece84d85dae3ed06b22e3f234575b93fd1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Jan 2016 13:52:29 +0000 Subject: Remove redundated BaseHomeServer --- synapse/app/homeserver.py | 137 +++++--------- synapse/federation/__init__.py | 9 +- synapse/federation/replication.py | 2 - synapse/federation/transport/__init__.py | 52 ------ synapse/federation/transport/client.py | 4 + synapse/federation/transport/server.py | 82 +++++---- synapse/server.py | 106 ++++++----- tests/federation/__init__.py | 0 tests/federation/test_federation.py | 303 ------------------------------- tests/rest/client/v1/test_presence.py | 18 +- tests/test_types.py | 5 +- tests/utils.py | 18 ++ 12 files changed, 191 insertions(+), 545 deletions(-) delete mode 100644 tests/federation/__init__.py delete mode 100644 tests/federation/test_federation.py (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6928d9d3e4..795c655ae3 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -50,16 +50,14 @@ from twisted.cred import checkers, portal from twisted.internet import reactor, task, defer from twisted.application import service -from twisted.enterprise import adbapi from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.static import File from twisted.web.server import Site, GzipEncoderFactory, Request -from synapse.http.server import JsonResource, RootRedirect +from synapse.http.server import RootRedirect from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, @@ -69,6 +67,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.federation.transport.server import TransportLayerServer from synapse import events @@ -95,80 +94,37 @@ def gz_wrap(r): return EncodingResourceWrapper(r, [GzipEncoderFactory()]) -class SynapseHomeServer(HomeServer): - - def build_http_client(self): - return MatrixFederationHttpClient(self) - - def build_client_resource(self): - return ClientRestResource(self) - - def build_resource_for_federation(self): - return JsonResource(self) - - def build_resource_for_web_client(self): - webclient_path = self.get_config().web_client_location - if not webclient_path: - try: - import syweb - except ImportError: - quit_with_error( - "Could not find a webclient.\n\n" - "Please either install the matrix-angular-sdk or configure\n" - "the location of the source to serve via the configuration\n" - "option `web_client_location`\n\n" - "To install the `matrix-angular-sdk` via pip, run:\n\n" - " pip install '%(dep)s'\n" - "\n" - "You can also disable hosting of the webclient via the\n" - "configuration option `web_client`\n" - % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]} - ) - syweb_path = os.path.dirname(syweb.__file__) - webclient_path = os.path.join(syweb_path, "webclient") - # GZip is disabled here due to - # https://twistedmatrix.com/trac/ticket/7678 - # (It can stay enabled for the API resources: they call - # write() with the whole body and then finish() straight - # after and so do not trigger the bug. - # GzipFile was removed in commit 184ba09 - # return GzipFile(webclient_path) # TODO configurable? - return File(webclient_path) # TODO configurable? - - def build_resource_for_static_content(self): - # This is old and should go away: not going to bother adding gzip - return File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ) - - def build_resource_for_content_repo(self): - return ContentRepoResource( - self, self.config.uploads_path, self.auth, self.content_addr - ) - - def build_resource_for_media_repository(self): - return MediaRepositoryResource(self) - - def build_resource_for_server_key(self): - return LocalKey(self) - - def build_resource_for_server_key_v2(self): - return KeyApiV2Resource(self) - - def build_resource_for_metrics(self): - if self.get_config().enable_metrics: - return MetricsResource(self) - else: - return None - - def build_db_pool(self): - name = self.db_config["name"] +def build_resource_for_web_client(hs): + webclient_path = hs.get_config().web_client_location + if not webclient_path: + try: + import syweb + except ImportError: + quit_with_error( + "Could not find a webclient.\n\n" + "Please either install the matrix-angular-sdk or configure\n" + "the location of the source to serve via the configuration\n" + "option `web_client_location`\n\n" + "To install the `matrix-angular-sdk` via pip, run:\n\n" + " pip install '%(dep)s'\n" + "\n" + "You can also disable hosting of the webclient via the\n" + "configuration option `web_client`\n" + % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]} + ) + syweb_path = os.path.dirname(syweb.__file__) + webclient_path = os.path.join(syweb_path, "webclient") + # GZip is disabled here due to + # https://twistedmatrix.com/trac/ticket/7678 + # (It can stay enabled for the API resources: they call + # write() with the whole body and then finish() straight + # after and so do not trigger the bug. + # GzipFile was removed in commit 184ba09 + # return GzipFile(webclient_path) # TODO configurable? + return File(webclient_path) # TODO configurable? - return adbapi.ConnectionPool( - name, - **self.db_config.get("args", {}) - ) +class SynapseHomeServer(HomeServer): def _listener_http(self, config, listener_config): port = listener_config["port"] bind_address = listener_config.get("bind_address", "") @@ -178,13 +134,11 @@ class SynapseHomeServer(HomeServer): if tls and config.no_tls: return - metrics_resource = self.get_resource_for_metrics() - resources = {} for res in listener_config["resources"]: for name in res["names"]: if name == "client": - client_resource = self.get_client_resource() + client_resource = ClientRestResource(self) if res["compress"]: client_resource = gz_wrap(client_resource) @@ -198,31 +152,35 @@ class SynapseHomeServer(HomeServer): if name == "federation": resources.update({ - FEDERATION_PREFIX: self.get_resource_for_federation(), + FEDERATION_PREFIX: TransportLayerServer(self), }) if name in ["static", "client"]: resources.update({ - STATIC_PREFIX: self.get_resource_for_static_content(), + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ), }) if name in ["media", "federation", "client"]: resources.update({ - MEDIA_PREFIX: self.get_resource_for_media_repository(), - CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(), + MEDIA_PREFIX: MediaRepositoryResource(self), + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path, self.auth, self.content_addr + ), }) if name in ["keys", "federation"]: resources.update({ - SERVER_KEY_PREFIX: self.get_resource_for_server_key(), - SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(), + SERVER_KEY_PREFIX: LocalKey(self), + SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), }) if name == "webclient": - resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client() + resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) - if name == "metrics" and metrics_resource: - resources[METRICS_PREFIX] = metrics_resource + if name == "metrics" and self.get_config().enable_metrics: + resources[METRICS_PREFIX] = MetricsResource(self) root_resource = create_resource_tree(resources) if tls: @@ -675,7 +633,7 @@ def _resource_id(resource, path_seg): the mapping should looks like _resource_id(A,C) = B. Args: - resource (Resource): The *parent* Resource + resource (Resource): The *parent* Resourceb path_seg (str): The name of the child Resource to be attached. Returns: str: A unique string which can be a key to the child Resource. @@ -684,7 +642,7 @@ def _resource_id(resource, path_seg): def run(hs): - PROFILE_SYNAPSE = False + PROFILE_SYNAPSE = True if PROFILE_SYNAPSE: def profile(func): from cProfile import Profile @@ -761,6 +719,7 @@ def run(hs): auto_close_fds=False, verbose=True, logger=logger, + chdir=os.path.dirname(os.path.abspath(__file__)), ) daemon.start() diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 0bfb79d09f..979fdf2431 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -17,15 +17,10 @@ """ from .replication import ReplicationLayer -from .transport import TransportLayer +from .transport.client import TransportLayerClient def initialize_http_replication(homeserver): - transport = TransportLayer( - homeserver, - homeserver.hostname, - server=homeserver.get_resource_for_federation(), - client=homeserver.get_http_client() - ) + transport = TransportLayerClient(homeserver) return ReplicationLayer(homeserver, transport) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 6e0be8ef15..3e062a5eab 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer): self.keyring = hs.get_keyring() self.transport_layer = transport_layer - self.transport_layer.register_received_handler(self) - self.transport_layer.register_request_handler(self) self.federation_client = self diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py index 155a7d5870..d9fcc520a0 100644 --- a/synapse/federation/transport/__init__.py +++ b/synapse/federation/transport/__init__.py @@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to support HTTPS), however individual pairings of servers may decide to communicate over a different (albeit still reliable) protocol. """ - -from .server import TransportLayerServer -from .client import TransportLayerClient - -from synapse.util.ratelimitutils import FederationRateLimiter - - -class TransportLayer(TransportLayerServer, TransportLayerClient): - """This is a basic implementation of the transport layer that translates - transactions and other requests to/from HTTP. - - Attributes: - server_name (str): Local home server host - - server (synapse.http.server.HttpServer): the http server to - register listeners on - - client (synapse.http.client.HttpClient): the http client used to - send requests - - request_handler (TransportRequestHandler): The handler to fire when we - receive requests for data. - - received_handler (TransportReceivedHandler): The handler to fire when - we receive data. - """ - - def __init__(self, homeserver, server_name, server, client): - """ - Args: - server_name (str): Local home server host - server (synapse.protocol.http.HttpServer): the http server to - register listeners on - client (synapse.protocol.http.HttpClient): the http client used to - send requests - """ - self.keyring = homeserver.get_keyring() - self.clock = homeserver.get_clock() - self.server_name = server_name - self.server = server - self.client = client - self.request_handler = None - self.received_handler = None - - self.ratelimiter = FederationRateLimiter( - self.clock, - window_size=homeserver.config.federation_rc_window_size, - sleep_limit=homeserver.config.federation_rc_sleep_limit, - sleep_msec=homeserver.config.federation_rc_sleep_delay, - reject_limit=homeserver.config.federation_rc_reject_limit, - concurrent_requests=homeserver.config.federation_rc_concurrent, - ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 949d01dea8..2b5d40ea7f 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class TransportLayerClient(object): """Sends federation HTTP requests to other servers""" + def __init__(self, hs): + self.server_name = hs.hostname + self.client = hs.get_http_client() + @log_function def get_room_state(self, destination, room_id, event_id): """ Requests all state for a given room from the given server at the diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8dca0a7f6b..65e054f7dd 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError -from synapse.util.logutils import log_function +from synapse.http.server import JsonResource +from synapse.util.ratelimitutils import FederationRateLimiter import functools import logging @@ -28,9 +29,41 @@ import re logger = logging.getLogger(__name__) -class TransportLayerServer(object): +class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" + def __init__(self, hs): + self.hs = hs + self.clock = hs.get_clock() + + super(TransportLayerServer, self).__init__(hs) + + self.authenticator = Authenticator(hs) + self.ratelimiter = FederationRateLimiter( + self.clock, + window_size=hs.config.federation_rc_window_size, + sleep_limit=hs.config.federation_rc_sleep_limit, + sleep_msec=hs.config.federation_rc_sleep_delay, + reject_limit=hs.config.federation_rc_reject_limit, + concurrent_requests=hs.config.federation_rc_concurrent, + ) + + self.register_servlets() + + def register_servlets(self): + register_servlets( + self.hs, + resource=self, + ratelimiter=self.ratelimiter, + authenticator=self.authenticator, + ) + + +class Authenticator(object): + def __init__(self, hs): + self.keyring = hs.get_keyring() + self.server_name = hs.hostname + # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks def authenticate_request(self, request): @@ -98,37 +131,9 @@ class TransportLayerServer(object): defer.returnValue((origin, content)) - @log_function - def register_received_handler(self, handler): - """ Register a handler that will be fired when we receive data. - - Args: - handler (TransportReceivedHandler) - """ - FederationSendServlet( - handler, - authenticator=self, - ratelimiter=self.ratelimiter, - server_name=self.server_name, - ).register(self.server) - - @log_function - def register_request_handler(self, handler): - """ Register a handler that will be fired when we get asked for data. - - Args: - handler (TransportRequestHandler) - """ - for servletclass in SERVLET_CLASSES: - servletclass( - handler, - authenticator=self, - ratelimiter=self.ratelimiter, - ).register(self.server) - class BaseFederationServlet(object): - def __init__(self, handler, authenticator, ratelimiter): + def __init__(self, handler, authenticator, ratelimiter, server_name): self.handler = handler self.authenticator = authenticator self.ratelimiter = ratelimiter @@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet): PATH = "/send/([^/]*)/" def __init__(self, handler, server_name, **kwargs): - super(FederationSendServlet, self).__init__(handler, **kwargs) + super(FederationSendServlet, self).__init__( + handler, server_name=server_name, **kwargs + ) self.server_name = server_name # This is when someone is trying to send us a bunch of data. @@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet): SERVLET_CLASSES = ( + FederationSendServlet, FederationPullServlet, FederationEventServlet, FederationStateServlet, @@ -451,3 +459,13 @@ SERVLET_CLASSES = ( FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, ) + + +def register_servlets(hs, resource, authenticator, ratelimiter): + for servletclass in SERVLET_CLASSES: + servletclass( + handler=hs.get_replication_layer(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) diff --git a/synapse/server.py b/synapse/server.py index 4a5796b982..a59e46ca2d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -20,6 +20,8 @@ # Imports required for the default HomeServer() implementation from twisted.web.client import BrowserLikePolicyForHTTPS +from twisted.enterprise import adbapi + from synapse.federation import initialize_http_replication from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier @@ -36,8 +38,10 @@ from synapse.push.pusherpool import PusherPool from synapse.events.builder import EventBuilderFactory from synapse.api.filtering import Filtering +from synapse.http.matrixfederationclient import MatrixFederationHttpClient + -class BaseHomeServer(object): +class HomeServer(object): """A basic homeserver object without lazy component builders. This will need all of the components it requires to either be passed as @@ -102,36 +106,6 @@ class BaseHomeServer(object): for depname in kwargs: setattr(self, depname, kwargs[depname]) - @classmethod - def _make_dependency_method(cls, depname): - def _get(self): - if hasattr(self, depname): - return getattr(self, depname) - - if hasattr(self, "build_%s" % (depname)): - # Prevent cyclic dependencies from deadlocking - if depname in self._building: - raise ValueError("Cyclic dependency while building %s" % ( - depname, - )) - self._building[depname] = 1 - - builder = getattr(self, "build_%s" % (depname)) - dep = builder() - setattr(self, depname, dep) - - del self._building[depname] - - return dep - - raise NotImplementedError( - "%s has no %s nor a builder for it" % ( - type(self).__name__, depname, - ) - ) - - setattr(BaseHomeServer, "get_%s" % (depname), _get) - def get_ip_from_request(self, request): # X-Forwarded-For is handled by our custom request type. return request.getClientIP() @@ -142,24 +116,6 @@ class BaseHomeServer(object): def is_mine_id(self, string): return string.split(":", 1)[1] == self.hostname -# Build magic accessors for every dependency -for depname in BaseHomeServer.DEPENDENCIES: - BaseHomeServer._make_dependency_method(depname) - - -class HomeServer(BaseHomeServer): - """A homeserver object that will construct most of its dependencies as - required. - - It still requires the following to be specified by the caller: - resource_for_client - resource_for_web_client - resource_for_federation - resource_for_content_repo - http_client - db_pool - """ - def build_clock(self): return Clock() @@ -224,3 +180,55 @@ class HomeServer(BaseHomeServer): def build_pusherpool(self): return PusherPool(self) + + def build_http_client(self): + return MatrixFederationHttpClient(self) + + def build_db_pool(self): + name = self.db_config["name"] + + return adbapi.ConnectionPool( + name, + **self.db_config.get("args", {}) + ) + + +def _make_dependency_method(depname): + def _get(hs): + try: + return getattr(hs, depname) + except AttributeError: + pass + + try: + builder = getattr(hs, "build_%s" % (depname)) + except AttributeError: + builder = None + + if builder: + # Prevent cyclic dependencies from deadlocking + if depname in hs._building: + raise ValueError("Cyclic dependency while building %s" % ( + depname, + )) + hs._building[depname] = 1 + + dep = builder() + setattr(hs, depname, dep) + + del hs._building[depname] + + return dep + + raise NotImplementedError( + "%s has no %s nor a builder for it" % ( + type(hs).__name__, depname, + ) + ) + + setattr(HomeServer, "get_%s" % (depname), _get) + + +# Build magic accessors for every dependency +for depname in HomeServer.DEPENDENCIES: + _make_dependency_method(depname) diff --git a/tests/federation/__init__.py b/tests/federation/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py deleted file mode 100644 index f2c2ee4127..0000000000 --- a/tests/federation/test_federation.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# trial imports -from twisted.internet import defer -from tests import unittest - -# python imports -from mock import Mock, ANY - -from ..utils import MockHttpResource, MockClock, setup_test_homeserver - -from synapse.federation import initialize_http_replication -from synapse.events import FrozenEvent - - -def make_pdu(prev_pdus=[], **kwargs): - """Provide some default fields for making a PduTuple.""" - pdu_fields = { - "state_key": None, - "prev_events": prev_pdus, - } - pdu_fields.update(kwargs) - - return FrozenEvent(pdu_fields) - - -class FederationTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource() - self.mock_http_client = Mock(spec=[ - "get_json", - "put_json", - ]) - self.mock_persistence = Mock(spec=[ - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_auth_chain", - ]) - self.mock_persistence.get_received_txn_response.return_value = ( - defer.succeed(None) - ) - - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - } - self.mock_persistence.get_destination_retry_timings.return_value = ( - defer.succeed(retry_timings_res) - ) - self.mock_persistence.get_auth_chain.return_value = [] - self.clock = MockClock() - hs = yield setup_test_homeserver( - resource_for_federation=self.mock_resource, - http_client=self.mock_http_client, - datastore=self.mock_persistence, - clock=self.clock, - keyring=Mock(), - ) - self.federation = initialize_http_replication(hs) - self.distributor = hs.get_distributor() - - @defer.inlineCallbacks - def test_get_state(self): - mock_handler = Mock(spec=[ - "get_state_for_pdu", - ]) - - self.federation.set_handler(mock_handler) - - mock_handler.get_state_for_pdu.return_value = defer.succeed([]) - - # Empty context initially - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/state/my-context/", - None - ) - self.assertEquals(200, code) - self.assertFalse(response["pdus"]) - - # Now lets give the context some state - mock_handler.get_state_for_pdu.return_value = ( - defer.succeed([ - make_pdu( - event_id="the-pdu-id", - origin="red", - user_id="@a:red", - room_id="my-context", - type="m.topic", - origin_server_ts=123456789000, - depth=1, - content={"topic": "The topic"}, - state_key="", - power_level=1000, - prev_state="last-pdu-id", - ), - ]) - ) - - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/state/my-context/", - None - ) - self.assertEquals(200, code) - self.assertEquals(1, len(response["pdus"])) - - @defer.inlineCallbacks - def test_get_pdu(self): - mock_handler = Mock(spec=[ - "get_persisted_pdu", - ]) - - self.federation.set_handler(mock_handler) - - mock_handler.get_persisted_pdu.return_value = ( - defer.succeed(None) - ) - - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/event/abc123def456/", - None - ) - self.assertEquals(404, code) - - # Now insert such a PDU - mock_handler.get_persisted_pdu.return_value = ( - defer.succeed( - make_pdu( - event_id="abc123def456", - origin="red", - user_id="@a:red", - room_id="my-context", - type="m.text", - origin_server_ts=123456789001, - depth=1, - content={"text": "Here is the message"}, - ) - ) - ) - - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/event/abc123def456/", - None - ) - self.assertEquals(200, code) - self.assertEquals(1, len(response["pdus"])) - self.assertEquals("m.text", response["pdus"][0]["type"]) - - @defer.inlineCallbacks - def test_send_pdu(self): - self.mock_http_client.put_json.return_value = defer.succeed( - (200, "OK") - ) - - pdu = make_pdu( - event_id="abc123def456", - origin="red", - user_id="@a:red", - room_id="my-context", - type="m.text", - origin_server_ts=123456789001, - depth=1, - content={"text": "Here is the message"}, - ) - - yield self.federation.send_pdu(pdu, ["remote"]) - - self.mock_http_client.put_json.assert_called_with( - "remote", - path="/_matrix/federation/v1/send/1000000/", - data={ - "origin_server_ts": 1000000, - "origin": "test", - "pdus": [ - pdu.get_pdu_json(), - ], - 'pdu_failures': [], - }, - json_data_callback=ANY, - long_retries=True, - ) - - @defer.inlineCallbacks - def test_send_edu(self): - self.mock_http_client.put_json.return_value = defer.succeed( - (200, "OK") - ) - - yield self.federation.send_edu( - destination="remote", - edu_type="m.test", - content={"testing": "content here"}, - ) - - # MockClock ensures we can guess these timestamps - self.mock_http_client.put_json.assert_called_with( - "remote", - path="/_matrix/federation/v1/send/1000000/", - data={ - "origin": "test", - "origin_server_ts": 1000000, - "pdus": [], - "edus": [ - { - "edu_type": "m.test", - "content": {"testing": "content here"}, - } - ], - 'pdu_failures': [], - }, - json_data_callback=ANY, - long_retries=True, - ) - - @defer.inlineCallbacks - def test_recv_edu(self): - recv_observer = Mock() - recv_observer.return_value = defer.succeed(()) - - self.federation.register_edu_handler("m.test", recv_observer) - - yield self.mock_resource.trigger( - "PUT", - "/_matrix/federation/v1/send/1001000/", - """{ - "origin": "remote", - "origin_server_ts": 1001000, - "pdus": [], - "edus": [ - { - "origin": "remote", - "destination": "test", - "edu_type": "m.test", - "content": {"testing": "reply here"} - } - ] - }""" - ) - - recv_observer.assert_called_with( - "remote", {"testing": "reply here"} - ) - - @defer.inlineCallbacks - def test_send_query(self): - self.mock_http_client.get_json.return_value = defer.succeed( - {"your": "response"} - ) - - response = yield self.federation.make_query( - destination="remote", - query_type="a-question", - args={"one": "1", "two": "2"}, - ) - - self.assertEquals({"your": "response"}, response) - - self.mock_http_client.get_json.assert_called_with( - destination="remote", - path="/_matrix/federation/v1/query/a-question", - args={"one": "1", "two": "2"}, - retry_on_dns_fail=True, - ) - - @defer.inlineCallbacks - def test_recv_query(self): - recv_handler = Mock() - recv_handler.return_value = defer.succeed({"another": "response"}) - - self.federation.register_query_handler("a-question", recv_handler) - - code, response = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/query/a-question?three=3&four=4", - None - ) - - self.assertEquals(200, code) - self.assertEquals({"another": "response"}, response) - - recv_handler.assert_called_with( - {"three": "3", "four": "4"} - ) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 90b911f879..8d7cfd79ab 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -280,6 +280,15 @@ class PresenceEventStreamTestCase(unittest.TestCase): } EventSources.SOURCE_TYPES["presence"] = PresenceEventSource + clock = Mock(spec=[ + "call_later", + "cancel_call_later", + "time_msec", + "looping_call", + ]) + + clock.time_msec.return_value = 1000000 + hs = yield setup_test_homeserver( http_client=None, resource_for_client=self.mock_resource, @@ -289,16 +298,9 @@ class PresenceEventStreamTestCase(unittest.TestCase): "get_presence_list", "get_rooms_for_user", ]), - clock=Mock(spec=[ - "call_later", - "cancel_call_later", - "time_msec", - "looping_call", - ]), + clock=clock, ) - hs.get_clock().time_msec.return_value = 1000000 - def _get_user_by_req(req=None, allow_guest=False): return Requester(UserID.from_string(myid), "", False) diff --git a/tests/test_types.py b/tests/test_types.py index b9534329e6..24d61dbe54 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -16,10 +16,10 @@ from tests import unittest from synapse.api.errors import SynapseError -from synapse.server import BaseHomeServer +from synapse.server import HomeServer from synapse.types import UserID, RoomAlias -mock_homeserver = BaseHomeServer(hostname="my.domain") +mock_homeserver = HomeServer(hostname="my.domain") class UserIDTestCase(unittest.TestCase): @@ -34,7 +34,6 @@ class UserIDTestCase(unittest.TestCase): with self.assertRaises(SynapseError): UserID.from_string("") - def test_build(self): user = UserID("5678efgh", "my.domain") diff --git a/tests/utils.py b/tests/utils.py index 358b5b72b7..d75d492cb5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.logcontext import LoggingContext @@ -80,6 +82,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) + fed = kargs.get("resource_for_federation", None) + if fed: + server.register_servlets( + hs, + resource=fed, + authenticator=server.Authenticator(hs), + ratelimiter=FederationRateLimiter( + hs.get_clock(), + window_size=hs.config.federation_rc_window_size, + sleep_limit=hs.config.federation_rc_sleep_limit, + sleep_msec=hs.config.federation_rc_sleep_delay, + reject_limit=hs.config.federation_rc_reject_limit, + concurrent_requests=hs.config.federation_rc_concurrent + ), + ) + defer.returnValue(hs) -- cgit 1.5.1 From 0487c9441f1439bd02cb4d107a4fcacfe5dbe75d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 Jan 2016 17:25:07 +0000 Subject: Fix tests --- tests/storage/test_appservice.py | 6 +++--- tests/storage/test_registration.py | 3 +-- tests/utils.py | 8 ++++++++ 3 files changed, 12 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 5abecdf6e0..ed8af10d87 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f2 = self._write_config(suffix="2") config = Mock(app_service_config_files=[f1, f2]) - hs = yield setup_test_homeserver(config=config) + hs = yield setup_test_homeserver(config=config, datastore=Mock()) ApplicationServiceStore(hs) @@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f2 = self._write_config(id="id", suffix="2") config = Mock(app_service_config_files=[f1, f2]) - hs = yield setup_test_homeserver(config=config) + hs = yield setup_test_homeserver(config=config, datastore=Mock()) with self.assertRaises(ConfigError) as cm: ApplicationServiceStore(hs) @@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f2 = self._write_config(as_token="as_token", suffix="2") config = Mock(app_service_config_files=[f1, f2]) - hs = yield setup_test_homeserver(config=config) + hs = yield setup_test_homeserver(config=config, datastore=Mock()) with self.assertRaises(ConfigError) as cm: ApplicationServiceStore(hs) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index a35efcc71e..7b3b4c13bc 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -18,7 +18,6 @@ from tests import unittest from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.storage.registration import RegistrationStore from synapse.util import stringutils from tests.utils import setup_test_homeserver @@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase): hs = yield setup_test_homeserver() self.db_pool = hs.get_db_pool() - self.store = RegistrationStore(hs) + self.store = hs.get_datastore() self.user_id = "@my-user:test" self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", diff --git a/tests/utils.py b/tests/utils.py index d75d492cb5..43cc2b30cd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -60,8 +60,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): name, db_pool=db_pool, config=config, version_string="Synapse/tests", database_engine=create_engine("sqlite3"), + get_db_conn=db_pool.get_db_conn, **kargs ) + hs.setup() else: hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, @@ -280,6 +282,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object): lambda conn: prepare_database(conn, engine) ) + def get_db_conn(self): + conn = self.connect() + engine = create_engine("sqlite3") + prepare_database(conn, engine) + return conn + class MemoryDataStore(object): -- cgit 1.5.1 From 5cba88ea7c96e5e8a9f3bc1a28cf3414b3083d60 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 27 Jan 2016 17:42:45 +0000 Subject: Make it possible to paginate forwards from stream tokens In order that we can fill the gap after a /sync, make it possible to paginate forwards from a stream token. --- synapse/handlers/message.py | 43 +++++++++++++++++++------------------- tests/rest/client/v1/test_rooms.py | 16 ++++++++++++-- 2 files changed, 35 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ff800f8af1..b73ad62147 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -105,8 +105,6 @@ class MessageHandler(BaseHandler): room_token = pagin_config.from_token.room_key room_token = RoomStreamToken.parse(room_token) - if room_token.topological is None: - raise SynapseError(400, "Invalid token") pagin_config.from_token = pagin_config.from_token.copy_and_replace( "room_key", str(room_token) @@ -117,27 +115,28 @@ class MessageHandler(BaseHandler): membership, member_event_id = yield self._check_in_room_or_world_readable( room_id, user_id ) - if membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room. - leave_token = yield self.store.get_topological_token_for_event( - member_event_id + + if source_config.direction == 'b': + # if we're going backwards, we might need to backfill. This + # requires that we have a topo token. + if room_token.topological is None: + raise SynapseError(400, "Invalid token: cannot paginate " + "backwards from a stream token") + + if membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + leave_token = yield self.store.get_topological_token_for_event( + member_event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < room_token.topological: + source_config.from_key = str(leave_token) + + yield self.hs.get_handlers().federation_handler.maybe_backfill( + room_id, room_token.topological ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < room_token.topological: - source_config.from_key = str(leave_token) - - if source_config.direction == "f": - if source_config.to_key is None: - source_config.to_key = str(leave_token) - else: - to_token = RoomStreamToken.parse(source_config.to_key) - if leave_token.topological < to_token.topological: - source_config.to_key = str(leave_token) - - yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, room_token.topological - ) events, next_key = yield data_source.get_pagination_rows( requester.user, source_config, room_id diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index cd03106e88..2fe6f695f5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1045,8 +1045,20 @@ class RoomMessageListTestCase(RestTestCase): self.assertTrue("end" in response) @defer.inlineCallbacks - def test_stream_token_is_rejected(self): + def test_stream_token_is_rejected_for_back_pagination(self): (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=s0_0_0_0" % + "/rooms/%s/messages?access_token=x&from=s0_0_0_0_0&dir=b" % self.room_id) self.assertEquals(400, code) + + @defer.inlineCallbacks + def test_stream_token_is_accepted_for_fwd_pagianation(self): + token = "s0_0_0_0_0" + (code, response) = yield self.mock_resource.trigger_get( + "/rooms/%s/messages?access_token=x&from=%s" % + (self.room_id, token)) + self.assertEquals(200, code) + self.assertTrue("start" in response) + self.assertEquals(token, response['start']) + self.assertTrue("chunk" in response) + self.assertTrue("end" in response) \ No newline at end of file -- cgit 1.5.1 From 4e7948b47a3f197682de82fc0cda07ebb08a581d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 11:52:34 +0000 Subject: Allow paginating backwards from stream token --- synapse/handlers/message.py | 15 +++++++++------ synapse/storage/stream.py | 16 ++++++++++++++-- tests/rest/client/v1/test_rooms.py | 9 +-------- 3 files changed, 24 insertions(+), 16 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b73ad62147..82c8cb5f0c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError, AuthError, Codes +from synapse.api.errors import AuthError, Codes from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -119,9 +119,12 @@ class MessageHandler(BaseHandler): if source_config.direction == 'b': # if we're going backwards, we might need to backfill. This # requires that we have a topo token. - if room_token.topological is None: - raise SynapseError(400, "Invalid token: cannot paginate " - "backwards from a stream token") + if room_token.topological: + max_topo = room_token.topological + else: + max_topo = yield self.store.get_max_topological_token_for_stream_and_room( + room_id, room_token.stream + ) if membership == Membership.LEAVE: # If they have left the room then clamp the token to be before @@ -131,11 +134,11 @@ class MessageHandler(BaseHandler): member_event_id ) leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < room_token.topological: + if leave_token.topological < max_topo: source_config.from_key = str(leave_token) yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, room_token.topological + room_id, max_topo ) events, next_key = yield data_source.get_pagination_rows( diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 28721e6994..5096b46864 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -234,10 +234,10 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) - ret.reverse() - self._set_before_and_after(ret, rows, topo_order=False) + ret.reverse() + if rows: key = "s%d" % min(r["stream_ordering"] for r in rows) else: @@ -570,6 +570,18 @@ class StreamStore(SQLBaseStore): row["topological_ordering"], row["stream_ordering"],) ) + def get_max_topological_token_for_stream_and_room(self, room_id, stream_key): + sql = ( + "SELECT max(topological_ordering) FROM events" + " WHERE room_id = ? AND stream_ordering < ?" + ) + return self._execute( + "get_max_topological_token_for_stream_and_room", None, + sql, room_id, stream_key, + ).addCallback( + lambda r: r[0][0] if r else 0 + ) + def _get_max_topological_txn(self, txn): txn.execute( "SELECT MAX(topological_ordering) FROM events" diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2fe6f695f5..ad5dd3bd6e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1044,13 +1044,6 @@ class RoomMessageListTestCase(RestTestCase): self.assertTrue("chunk" in response) self.assertTrue("end" in response) - @defer.inlineCallbacks - def test_stream_token_is_rejected_for_back_pagination(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=s0_0_0_0_0&dir=b" % - self.room_id) - self.assertEquals(400, code) - @defer.inlineCallbacks def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0" @@ -1061,4 +1054,4 @@ class RoomMessageListTestCase(RestTestCase): self.assertTrue("start" in response) self.assertEquals(token, response['start']) self.assertTrue("chunk" in response) - self.assertTrue("end" in response) \ No newline at end of file + self.assertTrue("end" in response) -- cgit 1.5.1 From 35981c8b71a2ce675f3b8414ca0e7920a5d1658e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 17:19:51 +0000 Subject: Fix test --- synapse/api/filtering.py | 5 +++++ tests/api/test_filtering.py | 7 ++++--- 2 files changed, 9 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 6c13ada5df..6eff83e5f8 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -15,6 +15,8 @@ from synapse.api.errors import SynapseError from synapse.types import UserID, RoomID +import ujson as json + class Filtering(object): @@ -149,6 +151,9 @@ class FilterCollection(object): "include_leave", False ) + def __repr__(self): + return "" % (json.dumps(self._filter_json),) + def get_filter_json(self): return self._filter_json diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 1a4e439d30..ceb0089268 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -382,19 +382,20 @@ class FilteringTestCase(unittest.TestCase): "types": ["m.*"] } } - user = UserID.from_string("@" + user_localpart + ":test") + filter_id = yield self.datastore.add_user_filter( - user_localpart=user_localpart, + user_localpart=user_localpart + "2", user_filter=user_filter_json, ) event = MockEvent( + event_id="$asdasd:localhost", sender="@foo:bar", type="custom.avatar.3d.crazy", ) events = [event] user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, + user_localpart=user_localpart + "2", filter_id=filter_id, ) -- cgit 1.5.1 From 4fce59f2747d9c73784470e84396a09ae70bbda7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 11:33:11 +0000 Subject: Add tests --- tests/util/test_lrucache.py | 7 +++++++ tests/util/test_treecache.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) (limited to 'tests') diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 2cd3d26454..bab366fb7f 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -19,6 +19,7 @@ from .. import unittest from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache + class LruCacheTestCase(unittest.TestCase): def test_get_set(self): @@ -72,3 +73,9 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(cache.get(("vehicles", "car")), "vroom") self.assertEquals(cache.get(("vehicles", "train")), "chuff") # Man from del_multi say "Yes". + + def test_clear(self): + cache = LruCache(1) + cache["key"] = 1 + cache.clear() + self.assertEquals(len(cache), 0) diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 9946ceb3f1..1efbeb6b33 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -25,6 +25,7 @@ class TreeCacheTestCase(unittest.TestCase): cache[("b",)] = "B" self.assertEquals(cache.get(("a",)), "A") self.assertEquals(cache.get(("b",)), "B") + self.assertEquals(len(cache), 2) def test_pop_onelevel(self): cache = TreeCache() @@ -33,6 +34,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEquals(cache.pop(("a",)), "A") self.assertEquals(cache.pop(("a",)), None) self.assertEquals(cache.get(("b",)), "B") + self.assertEquals(len(cache), 1) def test_get_set_twolevel(self): cache = TreeCache() @@ -42,6 +44,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEquals(cache.get(("a", "a")), "AA") self.assertEquals(cache.get(("a", "b")), "AB") self.assertEquals(cache.get(("b", "a")), "BA") + self.assertEquals(len(cache), 3) def test_pop_twolevel(self): cache = TreeCache() @@ -53,6 +56,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEquals(cache.get(("a", "b")), "AB") self.assertEquals(cache.pop(("b", "a")), "BA") self.assertEquals(cache.pop(("b", "a")), None) + self.assertEquals(len(cache), 1) def test_pop_mixedlevel(self): cache = TreeCache() @@ -64,3 +68,11 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEquals(cache.get(("a", "a")), None) self.assertEquals(cache.get(("a", "b")), None) self.assertEquals(cache.get(("b", "a")), "BA") + self.assertEquals(len(cache), 1) + + def test_clear(self): + cache = TreeCache() + cache[("a",)] = "A" + cache[("b",)] = "B" + cache.clear() + self.assertEquals(len(cache), 0) -- cgit 1.5.1 From f2d5ff5bf2cb95eb0a6619ae7fb40603175c8a7d Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 29 Jan 2016 14:53:14 +0000 Subject: Fix the mock homserver used in the tests --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index 43cc2b30cd..431252a6f1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.disable_registration = False config.macaroon_secret_key = "not even a little secret" config.server_name = "server.under.test" + config.trusted_third_party_id_servers = [] if "clock" not in kargs: kargs["clock"] = MockClock() -- cgit 1.5.1 From f8aae79a72e462f4af65a22d0665192867522174 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2016 13:23:32 +0000 Subject: Simplify get_rooms --- synapse/app/homeserver.py | 4 +-- synapse/storage/room.py | 84 ++++------------------------------------------ tests/storage/test_room.py | 26 -------------- 3 files changed, 9 insertions(+), 105 deletions(-) (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index c3066d6a0d..0a6a19033d 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -674,8 +674,8 @@ def run(hs): stats["uptime_seconds"] = uptime stats["total_users"] = yield hs.get_datastore().count_all_users() - all_rooms = yield hs.get_datastore().get_rooms(False) - stats["total_room_count"] = len(all_rooms) + room_count = yield hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() daily_messages = yield hs.get_datastore().count_daily_messages() diff --git a/synapse/storage/room.py b/synapse/storage/room.py index dc09a3aaba..46ab38a313 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore): desc="get_public_room_ids", ) - @defer.inlineCallbacks - def get_rooms(self, is_public): - """Retrieve a list of all public rooms. - - Args: - is_public (bool): True if the rooms returned should be public. - Returns: - A list of room dicts containing at least a "room_id" key, a - "topic" key if one is set, and a "name" key if one is set + def get_room_count(self): + """Retrieve a list of all rooms """ def f(txn): - def subquery(table_name, column_name=None): - column_name = column_name or table_name - return ( - "SELECT %(table_name)s.event_id as event_id, " - "%(table_name)s.room_id as room_id, %(column_name)s " - "FROM %(table_name)s " - "INNER JOIN current_state_events as c " - "ON c.event_id = %(table_name)s.event_id " % { - "column_name": column_name, - "table_name": table_name, - } - ) - - sql = ( - "SELECT" - " r.room_id," - " max(n.name)," - " max(t.topic)," - " max(v.history_visibility)," - " max(g.guest_access)" - " FROM rooms AS r" - " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id" - " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id" - " LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id" - " LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id" - " WHERE r.is_public = ?" - " GROUP BY r.room_id" % { - "topic": subquery("topics", "topic"), - "name": subquery("room_names", "name"), - "history_visibility": subquery("history_visibility"), - "guest_access": subquery("guest_access"), - } - ) - - txn.execute(sql, (is_public,)) - - rows = txn.fetchall() - - for i, row in enumerate(rows): - room_id = row[0] - aliases = self._simple_select_onecol_txn( - txn, - table="room_aliases", - keyvalues={ - "room_id": room_id - }, - retcol="room_alias", - ) + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 - rows[i] = list(row) + [aliases] - - return rows - - rows = yield self.runInteraction( + return self.runInteraction( "get_rooms", f ) - ret = [ - { - "room_id": r[0], - "name": r[1], - "topic": r[2], - "world_readable": r[3] == "world_readable", - "guest_can_join": r[4] == "can_join", - "aliases": r[5], - } - for r in rows - if r[5] # We only return rooms that have at least one alias. - ] - - defer.returnValue(ret) - def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: self._simple_insert_txn( diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 7fdbfc60f1..0baaf3df21 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -51,32 +51,6 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())) ) - @defer.inlineCallbacks - def test_get_rooms(self): - # get_rooms does an INNER JOIN on the room_aliases table :( - - rooms = yield self.store.get_rooms(is_public=True) - # Should be empty before we add the alias - self.assertEquals([], rooms) - - yield self.store.create_room_alias_association( - room_alias=self.alias, - room_id=self.room.to_string(), - servers=["test"] - ) - - rooms = yield self.store.get_rooms(is_public=True) - - self.assertEquals(1, len(rooms)) - self.assertEquals({ - "name": None, - "room_id": self.room.to_string(), - "topic": None, - "aliases": [self.alias.to_string()], - "world_readable": False, - "guest_can_join": False, - }, rooms[0]) - class RoomEventsStoreTestCase(unittest.TestCase): -- cgit 1.5.1 From 5054806ec1f64fd784d9e74d73a678643d539c3f Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 3 Feb 2016 14:42:01 +0000 Subject: Rename config field to reflect yaml name --- synapse/config/registration.py | 6 +++--- synapse/rest/client/v1/register.py | 4 ++-- synapse/rest/client/v2_alpha/register.py | 2 +- tests/rest/client/v1/test_events.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 4 ++-- tests/utils.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 76d2d2d640..90ea19bd4b 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -23,11 +23,11 @@ from distutils.util import strtobool class RegistrationConfig(Config): def read_config(self, config): - self.disable_registration = not bool( + self.enable_registration = bool( strtobool(str(config["enable_registration"])) ) if "disable_registration" in config: - self.disable_registration = bool( + self.enable_registration = not bool( strtobool(str(config["disable_registration"])) ) @@ -78,6 +78,6 @@ class RegistrationConfig(Config): def read_arguments(self, args): if args.enable_registration is not None: - self.disable_registration = not bool( + self.enable_registration = bool( strtobool(str(args.enable_registration)) ) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 2bfd4d96bf..6d6d03c34c 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -59,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet): # } # TODO: persistent storage self.sessions = {} - self.disable_registration = hs.config.disable_registration + self.enable_registration = hs.config.enable_registration def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -113,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet): is_using_shared_secret = login_type == LoginType.SHARED_SECRET can_register = ( - not self.disable_registration + self.enable_registration or is_application_server or is_using_shared_secret ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 56a5bbec30..ec5c21fa1f 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -117,7 +117,7 @@ class RegisterRestServlet(RestServlet): return # == Normal User Registration == (everyone else) - if self.hs.config.disable_registration: + if not self.hs.config.enable_registration: raise SynapseError(403, "Registration has been disabled") guest_access_token = body.get("guest_access_token", None) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index b260e269ac..e9698bfdc9 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -122,7 +122,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) hs.config.enable_registration_captcha = False - hs.config.disable_registration = False + hs.config.enable_registration = True hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index f9a2b22485..df0841b0b1 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -41,7 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) - self.hs.config.disable_registration = False + self.hs.config.enable_registration = True # init the thing we're testing self.servlet = RegisterRestServlet(self.hs) @@ -120,7 +120,7 @@ class RegisterRestServletTestCase(unittest.TestCase): })) def test_POST_disabled_registration(self): - self.hs.config.disable_registration = True + self.hs.config.enable_registration = False self.request_data = json.dumps({ "username": "kermit", "password": "monkey" diff --git a/tests/utils.py b/tests/utils.py index 431252a6f1..3b1eb50d8d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -46,7 +46,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config = Mock() config.signing_key = [MockKey()] config.event_cache_size = 1 - config.disable_registration = False + config.enable_registration = True config.macaroon_secret_key = "not even a little secret" config.server_name = "server.under.test" config.trusted_third_party_id_servers = [] -- cgit 1.5.1 From 6a9f1209dfe5b3c43726aff24000129856bdc084 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Fri, 5 Feb 2016 01:58:23 +0000 Subject: Error if macaroon key is missing from config Currently we store all access tokens in the DB, and fall back to that check if we can't validate the macaroon, so our fallback works here, but for guests, their macaroons don't get persisted, so we don't get to find them in the database. Each restart, we generate a new ephemeral key, so guests lose access after each server restart. I tried to fix up the config stuff to be less insane, but gave up, so instead I bolt on yet another piece of custom one-off insanity. Also, add some basic tests for config generation and loading. --- synapse/app/homeserver.py | 20 ++++++++--- synapse/config/__main__.py | 7 +++- synapse/config/_base.py | 35 ++++++++++++------- synapse/config/registration.py | 18 +++++++--- tests/config/__init__.py | 14 ++++++++ tests/config/test_generate.py | 50 +++++++++++++++++++++++++++ tests/config/test_load.py | 77 ++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 198 insertions(+), 23 deletions(-) create mode 100644 tests/config/__init__.py create mode 100644 tests/config/test_generate.py create mode 100644 tests/config/test_load.py (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 0a6a19033d..89238cb7e3 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -24,6 +24,7 @@ import resource import subprocess import sys import time +from synapse.config._base import ConfigError from synapse.python_dependencies import ( check_requirements, DEPENDENCY_LINKS @@ -350,11 +351,20 @@ def setup(config_options): Returns: HomeServer """ - config = HomeServerConfig.load_config( - "Synapse Homeserver", - config_options, - generate_section="Homeserver" - ) + try: + config = HomeServerConfig.load_config( + "Synapse Homeserver", + config_options, + generate_section="Homeserver" + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + if not config: + # If a config isn't returned, and an exception isn't raised, we're just + # generating config files and shouldn't try to continue. + sys.exit(0) config.setup_logging() diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index ea9e7907a6..0a3b70e11f 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.config._base import ConfigError if __name__ == "__main__": import sys @@ -21,7 +22,11 @@ if __name__ == "__main__": if action == "read": key = sys.argv[2] - config = HomeServerConfig.load_config("", sys.argv[3:]) + try: + config = HomeServerConfig.load_config("", sys.argv[3:]) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) print getattr(config, key) sys.exit(0) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index a9304a11ba..15d78ff33a 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -17,7 +17,6 @@ import argparse import errno import os import yaml -import sys from textwrap import dedent @@ -136,13 +135,20 @@ class Config(object): results.append(getattr(cls, name)(self, *args, **kargs)) return results - def generate_config(self, config_dir_path, server_name, report_stats=None): + def generate_config( + self, + config_dir_path, + server_name, + is_generating_file, + report_stats=None, + ): default_config = "# vim:ft=yaml\n" default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( "default_config", config_dir_path=config_dir_path, server_name=server_name, + is_generating_file=is_generating_file, report_stats=report_stats, )) @@ -244,8 +250,10 @@ class Config(object): server_name = config_args.server_name if not server_name: - print "Must specify a server_name to a generate config for." - sys.exit(1) + raise ConfigError( + "Must specify a server_name to a generate config for." + " Pass -H server.name." + ) if not os.path.exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: @@ -253,6 +261,7 @@ class Config(object): config_dir_path=config_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), + is_generating_file=True ) obj.invoke_all("generate_files", config) config_file.write(config_bytes) @@ -266,7 +275,7 @@ class Config(object): "If this server name is incorrect, you will need to" " regenerate the SSL certificates" ) - sys.exit(0) + return else: print ( "Config file %r already exists. Generating any missing key" @@ -302,25 +311,25 @@ class Config(object): specified_config.update(yaml_config) if "server_name" not in specified_config: - sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n") - sys.exit(1) + raise ConfigError(MISSING_SERVER_NAME) server_name = specified_config["server_name"] _, config = obj.generate_config( config_dir_path=config_dir_path, - server_name=server_name + server_name=server_name, + is_generating_file=False, ) config.pop("log_config") config.update(specified_config) if "report_stats" not in config: - sys.stderr.write( - "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + - MISSING_REPORT_STATS_SPIEL + "\n") - sys.exit(1) + raise ConfigError( + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + + MISSING_REPORT_STATS_SPIEL + ) if generate_keys: obj.invoke_all("generate_files", config) - sys.exit(0) + return obj.invoke_all("read_config", config) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 90ea19bd4b..9b6dacc5b8 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -33,12 +33,24 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") + if self.macaroon_secret_key is None: + raise Exception( + "Config is missing missing macaroon_secret_key - please set it" + " in your config file." + ) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] self.allow_guest_access = config.get("allow_guest_access", False) - def default_config(self, **kwargs): + def default_config(self, is_generating_file=False, **kwargs): registration_shared_secret = random_string_with_symbols(50) + + macaroon_line = "" + if is_generating_file: + macaroon_line += '\n macaroon_secret_key: "%s"\n' % ( + random_string_with_symbols(50), + ) + macaroon_secret_key = random_string_with_symbols(50) return """\ ## Registration ## @@ -49,9 +61,7 @@ class RegistrationConfig(Config): # If set, allows registration by anyone who also has the shared # secret, even if registration is otherwise disabled. registration_shared_secret: "%(registration_shared_secret)s" - - macaroon_secret_key: "%(macaroon_secret_key)s" - +%(macaroon_line)s # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. # The default number of rounds is 12. diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/config/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py new file mode 100644 index 0000000000..4329d73974 --- /dev/null +++ b/tests/config/test_generate.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +import shutil +import tempfile +from synapse.config.homeserver import HomeServerConfig +from tests import unittest + + +class ConfigGenerationTestCase(unittest.TestCase): + + def setUp(self): + self.dir = tempfile.mkdtemp() + print self.dir + self.file = os.path.join(self.dir, "homeserver.yaml") + + def tearDown(self): + shutil.rmtree(self.dir) + + def test_generate_config_generates_files(self): + HomeServerConfig.load_config("", [ + "--generate-config", + "-c", self.file, + "--report-stats=yes", + "-H", "lemurs.win" + ]) + + self.assertSetEqual( + set([ + "homeserver.yaml", + "lemurs.win.log.config", + "lemurs.win.signing.key", + "lemurs.win.tls.crt", + "lemurs.win.tls.dh", + "lemurs.win.tls.key", + ]), + set(os.listdir(self.dir)) + ) diff --git a/tests/config/test_load.py b/tests/config/test_load.py new file mode 100644 index 0000000000..7f41279715 --- /dev/null +++ b/tests/config/test_load.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +import shutil +import tempfile +import yaml +from synapse.config.homeserver import HomeServerConfig +from tests import unittest + + +class ConfigLoadingTestCase(unittest.TestCase): + + def setUp(self): + self.dir = tempfile.mkdtemp() + print self.dir + self.file = os.path.join(self.dir, "homeserver.yaml") + + def tearDown(self): + shutil.rmtree(self.dir) + + def test_load_fails_if_server_name_missing(self): + self.generate_config_and_remove_lines_containing("server_name") + with self.assertRaises(Exception): + HomeServerConfig.load_config("", ["-c", self.file]) + + def test_generates_and_loads_macaroon_secret_key(self): + self.generate_config() + + with open(self.file, + "r") as f: + raw = yaml.load(f) + self.assertIn("macaroon_secret_key", raw) + + config = HomeServerConfig.load_config("", ["-c", self.file]) + self.assertTrue( + hasattr(config, "macaroon_secret_key"), + "Want config to have attr macaroon_secret_key" + ) + if len(config.macaroon_secret_key) < 5: + self.fail( + "Want macaroon secret key to be string of at least length 5," + "was: %r" % (config.macaroon_secret_key,) + ) + + def test_load_fails_if_macaroon_secret_key_missing(self): + self.generate_config_and_remove_lines_containing("macaroon") + with self.assertRaises(Exception): + HomeServerConfig.load_config("", ["-c", self.file]) + + def generate_config(self): + HomeServerConfig.load_config("", [ + "--generate-config", + "-c", self.file, + "--report-stats=yes", + "-H", "lemurs.win" + ]) + + def generate_config_and_remove_lines_containing(self, needle): + self.generate_config() + + with open(self.file, "r") as f: + contents = f.readlines() + contents = [l for l in contents if needle not in l] + with open(self.file, "w") as f: + f.write("".join(contents)) -- cgit 1.5.1 From 13ba8d878ce30dbc16123886a78a0905fc9ad4a5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Feb 2016 14:55:21 +0000 Subject: Fix test --- tests/config/test_load.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 7f41279715..528e878532 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -54,10 +54,11 @@ class ConfigLoadingTestCase(unittest.TestCase): "was: %r" % (config.macaroon_secret_key,) ) - def test_load_fails_if_macaroon_secret_key_missing(self): + def test_load_suceeds_if_macaroon_secret_key_missing(self): self.generate_config_and_remove_lines_containing("macaroon") - with self.assertRaises(Exception): - HomeServerConfig.load_config("", ["-c", self.file]) + config1 = HomeServerConfig.load_config("", ["-c", self.file]) + config2 = HomeServerConfig.load_config("", ["-c", self.file]) + self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) def generate_config(self): HomeServerConfig.load_config("", [ -- cgit 1.5.1 From e664e9737ca8ff04043f747a9375ff50440352c2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Feb 2016 14:57:43 +0000 Subject: Fix test --- tests/util/test_log_context.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/util/test_log_context.py b/tests/util/test_log_context.py index efa0f28bad..65a330a0e9 100644 --- a/tests/util/test_log_context.py +++ b/tests/util/test_log_context.py @@ -5,6 +5,7 @@ from .. import unittest from synapse.util.async import sleep from synapse.util.logcontext import LoggingContext + class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): @@ -17,15 +18,6 @@ class LoggingContextTestCase(unittest.TestCase): context_one.test_key = "test" self._check_test_key("test") - def test_chaining(self): - with LoggingContext() as context_one: - context_one.test_key = "one" - with LoggingContext() as context_two: - self._check_test_key("one") - context_two.test_key = "two" - self._check_test_key("two") - self._check_test_key("one") - @defer.inlineCallbacks def test_sleep(self): @defer.inlineCallbacks -- cgit 1.5.1 From 78a54822670c94763e6de708797fd561260bbcf5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Feb 2016 16:23:11 +0000 Subject: Typo --- tests/config/test_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 528e878532..fbbbf93fef 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -54,7 +54,7 @@ class ConfigLoadingTestCase(unittest.TestCase): "was: %r" % (config.macaroon_secret_key,) ) - def test_load_suceeds_if_macaroon_secret_key_missing(self): + def test_load_succeeds_if_macaroon_secret_key_missing(self): self.generate_config_and_remove_lines_containing("macaroon") config1 = HomeServerConfig.load_config("", ["-c", self.file]) config2 = HomeServerConfig.load_config("", ["-c", self.file]) -- cgit 1.5.1