From 2d3462714e48dca46dd54b17ca29188a17261e28 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 18 Aug 2015 14:22:02 +0100 Subject: Issue macaroons as opaque auth tokens This just replaces random bytes with macaroons. The macaroons are not inspected by the client or server. In particular, they claim to have an expiry time, but nothing verifies that they have not expired. Follow-up commits will actually enforce the expiration, and allow for token refresh. See https://bit.ly/matrix-auth for more information --- tests/handlers/test_register.py | 70 +++++++++++++++++++++++++++++++++++++++++ tests/utils.py | 2 ++ 2 files changed, 72 insertions(+) create mode 100644 tests/handlers/test_register.py (limited to 'tests') diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py new file mode 100644 index 0000000000..b28b1a7ef0 --- /dev/null +++ b/tests/handlers/test_register.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pymacaroons + +from mock import Mock, NonCallableMock +from synapse.handlers.register import RegistrationHandler +from tests import unittest +from tests.utils import setup_test_homeserver +from twisted.internet import defer + + +class RegisterHandlers(object): + def __init__(self, hs): + self.registration_handler = RegistrationHandler(hs) + + +class RegisterTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = RegisterHandlers(self.hs) + + def test_token_is_a_macaroon(self): + self.hs.config.macaroon_secret_key = "this key is a huge secret" + + token = self.hs.handlers.registration_handler.generate_token("some_user") + # Check that we can parse the thing with pymacaroons + macaroon = pymacaroons.Macaroon.deserialize(token) + # The most basic of sanity checks + if "some_user" not in macaroon.inspect(): + self.fail("some_user was not in %s" % macaroon.inspect()) + + def test_macaroon_caveats(self): + self.hs.config.macaroon_secret_key = "this key is a massive secret" + self.hs.clock.now = 5000 + + token = self.hs.handlers.registration_handler.generate_token("a_user") + macaroon = pymacaroons.Macaroon.deserialize(token) + + def verify_gen(caveat): + return caveat == "gen = 1" + + def verify_user(caveat): + return caveat == "user_id = a_user" + + def verify_type(caveat): + return caveat == "type = access" + + def verify_expiry(caveat): + return caveat == "time < 8600" + + v = pymacaroons.Verifier() + v.satisfy_general(verify_gen) + v.satisfy_general(verify_user) + v.satisfy_general(verify_type) + v.satisfy_general(verify_expiry) + v.verify(macaroon, self.hs.config.macaroon_secret_key) \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index eb035cf48f..80be70b74f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -44,6 +44,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.signing_key = [MockKey()] config.event_cache_size = 1 config.disable_registration = False + config.macaroon_secret_key = "not even a little secret" + config.server_name = "server.under.test" if "clock" not in kargs: kargs["clock"] = MockClock() -- cgit 1.5.1 From 3e6fdfda002de6971b74aba7805ebdeb2b1b426d Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 18 Aug 2015 15:18:50 +0100 Subject: Fix some formatting to use tuples --- synapse/handlers/register.py | 8 ++++---- tests/handlers/test_register.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 86bacdda1d..c391c1bdf5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -279,11 +279,11 @@ class RegistrationHandler(BaseHandler): identifier = "key", key = self.hs.config.macaroon_secret_key) macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % user_id) + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time() - expiry = now + 60 * 60 - macaroon.add_first_party_caveat("time < %s" % expiry) + now = self.hs.get_clock().time_msec() + expiry = now + (60 * 60 * 1000) + macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index b28b1a7ef0..0766affe81 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -67,4 +67,4 @@ class RegisterTestCase(unittest.TestCase): v.satisfy_general(verify_user) v.satisfy_general(verify_type) v.satisfy_general(verify_expiry) - v.verify(macaroon, self.hs.config.macaroon_secret_key) \ No newline at end of file + v.verify(macaroon, self.hs.config.macaroon_secret_key) -- cgit 1.5.1 From 42e858daeb59b86c451e3f49d40c1f418c8f0748 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 18 Aug 2015 17:38:37 +0100 Subject: Fix units in test I made the non-test seconds instead of ms, but not the test --- tests/handlers/test_register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 0766affe81..91cc90242f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -60,7 +60,7 @@ class RegisterTestCase(unittest.TestCase): return caveat == "type = access" def verify_expiry(caveat): - return caveat == "time < 8600" + return caveat == "time < 8600000" v = pymacaroons.Verifier() v.satisfy_general(verify_gen) -- cgit 1.5.1 From ce832c38d4ba1412cd5b5f8a4fb9328cb2d444fa Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 18 Aug 2015 17:39:26 +0100 Subject: Remove padding space around caveat operators --- synapse/handlers/register.py | 8 ++++---- tests/handlers/test_register.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c391c1bdf5..557aec4e6c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -278,12 +278,12 @@ class RegistrationHandler(BaseHandler): location = self.hs.config.server_name, identifier = "key", key = self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("gen=1") + macaroon.add_first_party_caveat("user_id=%s" % (user_id,)) + macaroon.add_first_party_caveat("type=access") now = self.hs.get_clock().time_msec() expiry = now + (60 * 60 * 1000) - macaroon.add_first_party_caveat("time < %d" % (expiry,)) + macaroon.add_first_party_caveat("time<%d" % (expiry,)) return macaroon.serialize() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 91cc90242f..18507c547d 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -51,16 +51,16 @@ class RegisterTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon.deserialize(token) def verify_gen(caveat): - return caveat == "gen = 1" + return caveat == "gen=1" def verify_user(caveat): - return caveat == "user_id = a_user" + return caveat == "user_id=a_user" def verify_type(caveat): - return caveat == "type = access" + return caveat == "type=access" def verify_expiry(caveat): - return caveat == "time < 8600000" + return caveat == "time<8600000" v = pymacaroons.Verifier() v.satisfy_general(verify_gen) -- cgit 1.5.1 From 70e265e695a67a412b5ac76cc9bae71e9d384e80 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 19 Aug 2015 14:30:31 +0100 Subject: Re-add whitespace around caveat operators --- synapse/handlers/register.py | 8 ++++---- tests/handlers/test_register.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 557aec4e6c..c391c1bdf5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -278,12 +278,12 @@ class RegistrationHandler(BaseHandler): location = self.hs.config.server_name, identifier = "key", key = self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen=1") - macaroon.add_first_party_caveat("user_id=%s" % (user_id,)) - macaroon.add_first_party_caveat("type=access") + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() expiry = now + (60 * 60 * 1000) - macaroon.add_first_party_caveat("time<%d" % (expiry,)) + macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 18507c547d..91cc90242f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -51,16 +51,16 @@ class RegisterTestCase(unittest.TestCase): macaroon = pymacaroons.Macaroon.deserialize(token) def verify_gen(caveat): - return caveat == "gen=1" + return caveat == "gen = 1" def verify_user(caveat): - return caveat == "user_id=a_user" + return caveat == "user_id = a_user" def verify_type(caveat): - return caveat == "type=access" + return caveat == "type = access" def verify_expiry(caveat): - return caveat == "time<8600000" + return caveat == "time < 8600000" v = pymacaroons.Verifier() v.satisfy_general(verify_gen) -- cgit 1.5.1 From 617501dd2a0562f4bf7edf8bc7a4e8aeb16b2254 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 11:35:56 +0100 Subject: Move token generation to auth handler I prefer the auth handler to worry about all auth, and register to call into it as needed, than to smatter auth logic between the two. --- synapse/handlers/auth.py | 29 ++++++++++++++--- synapse/handlers/register.py | 26 +++++---------- tests/handlers/test_auth.py | 70 +++++++++++++++++++++++++++++++++++++++++ tests/handlers/test_register.py | 70 ----------------------------------------- 4 files changed, 101 insertions(+), 94 deletions(-) create mode 100644 tests/handlers/test_auth.py delete mode 100644 tests/handlers/test_register.py (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index be2baeaece..0bf917efdd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError import logging import bcrypt +import pymacaroons import simplejson import synapse.util.stringutils as stringutils @@ -284,12 +285,9 @@ class AuthHandler(BaseHandler): LoginError if there was an authentication problem. """ yield self._check_password(user_id, password) - - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_token(user_id) logger.info("Logging in user %s", user_id) - yield self.store.add_access_token_to_user(user_id, access_token) - defer.returnValue(access_token) + token = yield self.issue_access_token(user_id) + defer.returnValue(token) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -304,6 +302,27 @@ class AuthHandler(BaseHandler): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) + @defer.inlineCallbacks + def issue_access_token(self, user_id): + reg_handler = self.hs.get_handlers().registration_handler + access_token = reg_handler.generate_access_token(user_id) + yield self.store.add_access_token_to_user(user_id, access_token) + defer.returnValue(access_token) + + def generate_access_token(self, user_id): + macaroon = pymacaroons.Macaroon( + location = self.hs.config.server_name, + identifier = "key", + key = self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("type = access") + now = self.hs.get_clock().time_msec() + expiry = now + (60 * 60 * 1000) + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + + return macaroon.serialize() + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c391c1bdf5..3d1b6531c2 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -27,7 +27,6 @@ from synapse.http.client import CaptchaServerHttpClient import bcrypt import logging -import pymacaroons import urllib logger = logging.getLogger(__name__) @@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler): 400, "Invalid user localpart for this application service.", errcode=Codes.EXCLUSIVE ) - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -273,20 +272,6 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) - def generate_token(self, user_id): - macaroon = pymacaroons.Macaroon( - location = self.hs.config.server_name, - identifier = "key", - key = self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - - return macaroon.serialize() - def _generate_user_id(self): return "-" + stringutils.random_string(18) @@ -329,3 +314,6 @@ class RegistrationHandler(BaseHandler): } ) defer.returnValue(data) + + def auth_handler(self): + return self.hs.get_handlers().auth_handler diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py new file mode 100644 index 0000000000..978e4d0d2e --- /dev/null +++ b/tests/handlers/test_auth.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pymacaroons + +from mock import Mock, NonCallableMock +from synapse.handlers.auth import AuthHandler +from tests import unittest +from tests.utils import setup_test_homeserver +from twisted.internet import defer + + +class AuthHandlers(object): + def __init__(self, hs): + self.auth_handler = AuthHandler(hs) + + +class AuthTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = AuthHandlers(self.hs) + + def test_token_is_a_macaroon(self): + self.hs.config.macaroon_secret_key = "this key is a huge secret" + + token = self.hs.handlers.auth_handler.generate_access_token("some_user") + # Check that we can parse the thing with pymacaroons + macaroon = pymacaroons.Macaroon.deserialize(token) + # The most basic of sanity checks + if "some_user" not in macaroon.inspect(): + self.fail("some_user was not in %s" % macaroon.inspect()) + + def test_macaroon_caveats(self): + self.hs.config.macaroon_secret_key = "this key is a massive secret" + self.hs.clock.now = 5000 + + token = self.hs.handlers.auth_handler.generate_access_token("a_user") + macaroon = pymacaroons.Macaroon.deserialize(token) + + def verify_gen(caveat): + return caveat == "gen = 1" + + def verify_user(caveat): + return caveat == "user_id = a_user" + + def verify_type(caveat): + return caveat == "type = access" + + def verify_expiry(caveat): + return caveat == "time < 8600000" + + v = pymacaroons.Verifier() + v.satisfy_general(verify_gen) + v.satisfy_general(verify_user) + v.satisfy_general(verify_type) + v.satisfy_general(verify_expiry) + v.verify(macaroon, self.hs.config.macaroon_secret_key) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py deleted file mode 100644 index 91cc90242f..0000000000 --- a/tests/handlers/test_register.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pymacaroons - -from mock import Mock, NonCallableMock -from synapse.handlers.register import RegistrationHandler -from tests import unittest -from tests.utils import setup_test_homeserver -from twisted.internet import defer - - -class RegisterHandlers(object): - def __init__(self, hs): - self.registration_handler = RegistrationHandler(hs) - - -class RegisterTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) - self.hs.handlers = RegisterHandlers(self.hs) - - def test_token_is_a_macaroon(self): - self.hs.config.macaroon_secret_key = "this key is a huge secret" - - token = self.hs.handlers.registration_handler.generate_token("some_user") - # Check that we can parse the thing with pymacaroons - macaroon = pymacaroons.Macaroon.deserialize(token) - # The most basic of sanity checks - if "some_user" not in macaroon.inspect(): - self.fail("some_user was not in %s" % macaroon.inspect()) - - def test_macaroon_caveats(self): - self.hs.config.macaroon_secret_key = "this key is a massive secret" - self.hs.clock.now = 5000 - - token = self.hs.handlers.registration_handler.generate_token("a_user") - macaroon = pymacaroons.Macaroon.deserialize(token) - - def verify_gen(caveat): - return caveat == "gen = 1" - - def verify_user(caveat): - return caveat == "user_id = a_user" - - def verify_type(caveat): - return caveat == "type = access" - - def verify_expiry(caveat): - return caveat == "time < 8600000" - - v = pymacaroons.Verifier() - v.satisfy_general(verify_gen) - v.satisfy_general(verify_user) - v.satisfy_general(verify_type) - v.satisfy_general(verify_expiry) - v.verify(macaroon, self.hs.config.macaroon_secret_key) -- cgit 1.5.1 From 13a6517d89c0619a938321640f331571eba0edc9 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 16:01:29 +0100 Subject: s/by_token/by_access_token/g We're about to have two kinds of token, access and refresh --- synapse/api/auth.py | 6 +++--- synapse/storage/registration.py | 6 +++--- tests/api/test_auth.py | 16 ++++++++-------- tests/rest/client/v1/test_presence.py | 8 ++++---- tests/rest/client/v1/test_rooms.py | 28 ++++++++++++++-------------- tests/rest/client/v1/test_typing.py | 4 ++-- tests/rest/client/v1/utils.py | 2 +- tests/rest/client/v2_alpha/__init__.py | 4 ++-- tests/storage/test_registration.py | 4 ++-- tests/utils.py | 2 +- 10 files changed, 40 insertions(+), 40 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1e3b0fbfb7..3d9237ccc3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -361,7 +361,7 @@ class Auth(object): except KeyError: pass # normal users won't have the user_id query parameter set. - user_info = yield self.get_user_by_token(access_token) + user_info = yield self.get_user_by_access_token(access_token) user = user_info["user"] device_id = user_info["device_id"] token_id = user_info["token_id"] @@ -390,7 +390,7 @@ class Auth(object): ) @defer.inlineCallbacks - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): """ Get a registered user's ID. Args: @@ -401,7 +401,7 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ - ret = yield self.store.get_user_by_token(token) + ret = yield self.store.get_user_by_access_token(token) if not ret: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bf803f2c6e..0e404afb7c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -132,10 +132,10 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate((r,)) + self.get_user_by_access_token.invalidate((r,)) @cached() - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): """Get a user from the given access token. Args: @@ -147,7 +147,7 @@ class RegistrationStore(SQLBaseStore): StoreError if no user was found. """ return self.runInteraction( - "get_user_by_token", + "get_user_by_access_token", self._query_for_auth, token ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4f83db5e84..3343c635cc 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase): "token_id": "ditto", "admin": False } - self.store.get_user_by_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -54,7 +54,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_user_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase): "token_id": "ditto", "admin": False } - self.store.get_user_by_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = Mock(return_value=[""]) @@ -81,7 +81,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -91,7 +91,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.requestHeaders.getRawHeaders = Mock(return_value=[""]) @@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 089a71568c..0b78a82a66 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -70,7 +70,7 @@ class PresenceStateTestCase(unittest.TestCase): return defer.succeed([]) self.datastore.get_presence_list = get_presence_list - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), "admin": False, @@ -78,7 +78,7 @@ class PresenceStateTestCase(unittest.TestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase): ) self.datastore.has_presence_state = has_presence_state - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), "admin": False, @@ -173,7 +173,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token presence.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index c83348acf9..2e55cc08a1 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -54,14 +54,14 @@ class RoomPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -441,14 +441,14 @@ class RoomsMemberListTestCase(RestTestCase): self.auth_user_id = self.user_id - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -521,14 +521,14 @@ class RoomsCreateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, @@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -721,14 +721,14 @@ class RoomMemberStateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -848,14 +848,14 @@ class RoomMessagesTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -945,14 +945,14 @@ class RoomInitialSyncTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7d8b1c2683..dc8bbaaf0e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -61,7 +61,7 @@ class RoomTypingTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, @@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 579441fb4a..c472d53043 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -37,7 +37,7 @@ class RestTestCase(unittest.TestCase): self.mock_resource = None self.auth_user_id = None - def mock_get_user_by_token(self, token=None): + def mock_get_user_by_access_token(self, token=None): return self.auth_user_id @defer.inlineCallbacks diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index de5a917e6a..15568b36cd 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -43,14 +43,14 @@ class V2AlphaRestTestCase(unittest.TestCase): resource_for_federation=self.mock_resource, ) - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.USER_ID), "admin": False, "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_auth().get_user_by_access_token = _get_user_by_access_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 2702291178..7a24cf898a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase): (yield self.store.get_user_by_id(self.user_id)) ) - result = yield self.store.get_user_by_token(self.tokens[0]) + result = yield self.store.get_user_by_access_token(self.tokens[0]) self.assertDictContainsSubset( { @@ -64,7 +64,7 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) - result = yield self.store.get_user_by_token(self.tokens[1]) + result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { diff --git a/tests/utils.py b/tests/utils.py index 80be70b74f..d0fba2252d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -277,7 +277,7 @@ class MemoryDataStore(object): raise StoreError(400, "User in use.") self.tokens_to_users[token] = user_id - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): try: return { "name": self.tokens_to_users[token], -- cgit 1.5.1 From cecbd636e94f4e900ef6d246b62698ff1c8ee352 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 16:21:35 +0100 Subject: /tokenrefresh POST endpoint This allows refresh tokens to be exchanged for (access_token, refresh_token). It also starts issuing them on login, though no clients currently interpret them. --- synapse/handlers/auth.py | 35 ++++++++++-- synapse/rest/client/v1/login.py | 6 ++- synapse/rest/client/v2_alpha/__init__.py | 2 + synapse/rest/client/v2_alpha/tokenrefresh.py | 56 +++++++++++++++++++ synapse/storage/__init__.py | 2 +- synapse/storage/_base.py | 1 + synapse/storage/registration.py | 62 ++++++++++++++++++++++ synapse/storage/schema/delta/23/refresh_tokens.sql | 21 ++++++++ tests/storage/test_registration.py | 55 +++++++++++++++++++ 9 files changed, 232 insertions(+), 8 deletions(-) create mode 100644 synapse/rest/client/v2_alpha/tokenrefresh.py create mode 100644 synapse/storage/schema/delta/23/refresh_tokens.sql (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0bf917efdd..65bd8189dc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -279,15 +279,18 @@ class AuthHandler(BaseHandler): user_id (str): User ID password (str): Password Returns: - The access token for the user's session. + A tuple of: + The access token for the user's session. + The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ yield self._check_password(user_id, password) logger.info("Logging in user %s", user_id) - token = yield self.issue_access_token(user_id) - defer.returnValue(token) + access_token = yield self.issue_access_token(user_id) + refresh_token = yield self.issue_refresh_token(user_id) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -304,11 +307,16 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id): - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_access_token(user_id) + access_token = self.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token) defer.returnValue(access_token) + @defer.inlineCallbacks + def issue_refresh_token(self, user_id): + refresh_token = self.generate_refresh_token(user_id) + yield self.store.add_refresh_token_to_user(user_id, refresh_token) + defer.returnValue(refresh_token) + def generate_access_token(self, user_id): macaroon = pymacaroons.Macaroon( location = self.hs.config.server_name, @@ -323,6 +331,23 @@ class AuthHandler(BaseHandler): return macaroon.serialize() + def generate_refresh_token(self, user_id): + m = self._generate_base_macaroon(user_id) + m.add_first_party_caveat("type = refresh") + # Important to add a nonce, because otherwise every refresh token for a + # user will be the same. + m.add_first_party_caveat("nonce = %s" % stringutils.random_string_with_symbols(16)) + return m.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location = self.hs.config.server_name, + identifier = "key", + key = self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 694072693d..b963a38618 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -78,13 +78,15 @@ class LoginRestServlet(ClientV1RestServlet): login_submission["user"] = UserID.create( login_submission["user"], self.hs.hostname).to_string() - token = yield self.handlers.auth_handler.login_with_password( + auth_handler = self.handlers.auth_handler + access_token, refresh_token = yield auth_handler.login_with_password( user_id=login_submission["user"], password=login_submission["password"]) result = { "user_id": login_submission["user"], # may have changed - "access_token": token, + "access_token": access_token, + "refresh_token": refresh_token, "home_server": self.hs.hostname, } diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 33f961e898..5831ff0e62 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -21,6 +21,7 @@ from . import ( auth, receipts, keys, + tokenrefresh, ) from synapse.http.server import JsonResource @@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource): auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) + tokenrefresh.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py new file mode 100644 index 0000000000..901e777983 --- /dev/null +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import AuthError, StoreError, SynapseError +from synapse.http.servlet import RestServlet + +from ._base import client_v2_pattern, parse_json_dict_from_request + + +class TokenRefreshRestServlet(RestServlet): + """ + Exchanges refresh tokens for a pair of an access token and a new refresh + token. + """ + PATTERN = client_v2_pattern("/tokenrefresh") + + def __init__(self, hs): + super(TokenRefreshRestServlet, self).__init__() + self.hs = hs + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_dict_from_request(request) + try: + old_refresh_token = body["refresh_token"] + auth_handler = self.hs.get_handlers().auth_handler + (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token) + new_access_token = yield auth_handler.issue_access_token(user_id) + defer.returnValue((200, { + "access_token": new_access_token, + "refresh_token": new_refresh_token, + })) + except KeyError: + raise SynapseError(400, "Missing required key 'refresh_token'.") + except StoreError: + raise AuthError(403, "Did not recognize refresh token") + + +def register_servlets(hs, http_server): + TokenRefreshRestServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f154b1c8ae..53673b3bf5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 22 +SCHEMA_VERSION = 23 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1444767a52..ce71389f02 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -181,6 +181,7 @@ class SQLBaseStore(object): self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) + self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0e404afb7c..f632306688 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -50,6 +50,28 @@ class RegistrationStore(SQLBaseStore): desc="add_access_token_to_user", ) + @defer.inlineCallbacks + def add_refresh_token_to_user(self, user_id, token): + """Adds a refresh token for the given user. + + Args: + user_id (str): The user ID. + token (str): The new refresh token to add. + Raises: + StoreError if there was a problem adding this. + """ + next_id = yield self._refresh_tokens_id_gen.get_next() + + yield self._simple_insert( + "refresh_tokens", + { + "id": next_id, + "user_id": user_id, + "token": token + }, + desc="add_refresh_token_to_user", + ) + @defer.inlineCallbacks def register(self, user_id, token, password_hash): """Attempts to register an account. @@ -152,6 +174,46 @@ class RegistrationStore(SQLBaseStore): token ) + def exchange_refresh_token(self, refresh_token, token_generator): + """Exchange a refresh token for a new access token and refresh token. + + Doing so invalidates the old refresh token - refresh tokens are single + use. + + Args: + token (str): The refresh token of a user. + token_generator (fn: str -> str): Function which, when given a + user ID, returns a unique refresh token for that user. This + function must never return the same value twice. + Returns: + tuple of (user_id, refresh_token) + Raises: + StoreError if no user was found with that refresh token. + """ + return self.runInteraction( + "exchange_refresh_token", + self._exchange_refresh_token, + refresh_token, + token_generator + ) + + def _exchange_refresh_token(self, txn, old_token, token_generator): + sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + txn.execute(sql, (old_token,)) + rows = self.cursor_to_dict(txn) + if not rows: + raise StoreError(403, "Did not recognize refresh token") + user_id = rows[0]["user_id"] + + # TODO(danielwh): Maybe perform a validation on the macaroon that + # macaroon.user_id == user_id. + + new_token = token_generator(user_id) + sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" + txn.execute(sql, (new_token, old_token,)) + + return user_id, new_token + @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql new file mode 100644 index 0000000000..46839e016c --- /dev/null +++ b/synapse/storage/schema/delta/23/refresh_tokens.sql @@ -0,0 +1,21 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS refresh_tokens( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (token) +); diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7a24cf898a..a4f929796a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,7 +17,9 @@ from tests import unittest from twisted.internet import defer +from synapse.api.errors import StoreError from synapse.storage.registration import RegistrationStore +from synapse.util import stringutils from tests.utils import setup_test_homeserver @@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() + self.db_pool = hs.get_db_pool() self.store = RegistrationStore(hs) @@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertTrue("token_id" in result) + @defer.inlineCallbacks + def test_exchange_refresh_token_valid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, last_token,)) + + (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( + last_token, generator.generate) + self.assertEqual(uid, found_user_id) + + rows = yield self.db_pool.runQuery( + "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) + self.assertEqual([(refresh_token,)], rows) + # We issued token 1, then exchanged it for token 2 + expected_refresh_token = u"%s-%d" % (uid, 2,) + self.assertEqual(expected_refresh_token, refresh_token) + + @defer.inlineCallbacks + def test_exchange_refresh_token_none(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + @defer.inlineCallbacks + def test_exchange_refresh_token_invalid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + wrong_token = "%s-wrong" % (last_token,) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, wrong_token,)) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + +class TokenGenerator: + def __init__(self): + self._last_issued_token = 0 + + def generate(self, user_id): + self._last_issued_token += 1 + return u"%s-%d" % (user_id, self._last_issued_token,) -- cgit 1.5.1 From a0b181bd17cb7ec2a43ed2dbdeb1bb40f3f4373c Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 25 Aug 2015 16:23:06 +0100 Subject: Remove completely unused concepts from codebase Removes device_id and ClientInfo device_id is never actually written, and the matrix.org DB has no non-null entries for it. Right now, it's just cluttering up code. This doesn't remove the columns from the database, because that's fiddly. --- synapse/api/auth.py | 17 ++++++--------- synapse/handlers/admin.py | 1 + synapse/handlers/message.py | 9 +++----- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/directory.py | 4 ++-- synapse/rest/client/v1/events.py | 4 ++-- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/presence.py | 8 +++---- synapse/rest/client/v1/profile.py | 4 ++-- synapse/rest/client/v1/pusher.py | 4 ++-- synapse/rest/client/v1/room.py | 34 ++++++++++++++--------------- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/account.py | 4 ++-- synapse/rest/client/v2_alpha/filter.py | 4 ++-- synapse/rest/client/v2_alpha/keys.py | 6 ++--- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/media/v0/content_repository.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/storage/__init__.py | 7 +++--- synapse/storage/registration.py | 5 ++--- synapse/types.py | 4 ---- tests/api/test_auth.py | 8 +++---- tests/rest/client/v1/test_presence.py | 2 -- tests/rest/client/v1/test_rooms.py | 7 ------ tests/rest/client/v1/test_typing.py | 1 - tests/rest/client/v2_alpha/__init__.py | 1 - tests/storage/test_registration.py | 2 -- tests/utils.py | 3 +-- 29 files changed, 63 insertions(+), 90 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3d9237ccc3..1496db7dff 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function -from synapse.types import UserID, ClientInfo +from synapse.types import UserID import logging @@ -322,9 +322,9 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - tuple : of UserID and device string: - User ID object of the user making the request - ClientInfo object of the client instance the user is using + tuple of: + UserID (str) + Access token ID (str) Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -355,7 +355,7 @@ class Auth(object): request.authenticated_entity = user_id defer.returnValue( - (UserID.from_string(user_id), ClientInfo("", "")) + (UserID.from_string(user_id), "") ) return except KeyError: @@ -363,7 +363,6 @@ class Auth(object): user_info = yield self.get_user_by_access_token(access_token) user = user_info["user"] - device_id = user_info["device_id"] token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) @@ -375,14 +374,13 @@ class Auth(object): self.store.insert_client_ip( user=user, access_token=access_token, - device_id=user_info["device_id"], ip=ip_addr, user_agent=user_agent ) request.authenticated_entity = user.to_string() - defer.returnValue((user, ClientInfo(device_id, token_id))) + defer.returnValue((user, token_id,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -396,7 +394,7 @@ class Auth(object): Args: token (str): The access token to get the user by. Returns: - dict : dict that includes the user, device_id, and whether the + dict : dict that includes the user and whether the user is a server admin. Raises: AuthError if no user by that token exists or the token is invalid. @@ -409,7 +407,6 @@ class Auth(object): ) user_info = { "admin": bool(ret.get("admin", False)), - "device_id": ret.get("device_id"), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1c9e7152c7..d852a18555 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -34,6 +34,7 @@ class AdminHandler(BaseHandler): d = {} for r in res: + # Note that device_id is always None device = d.setdefault(r["device_id"], {}) session = device.setdefault(r["access_token"], []) session.append({ diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f12465fa2c..23b779ad7c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -183,7 +183,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, - client=None, txn_id=None): + token_id=None, txn_id=None): """ Given a dict from a client, create and handle a new event. Creates an FrozenEvent object, filling out auth_events, prev_events, @@ -217,11 +217,8 @@ class MessageHandler(BaseHandler): builder.content ) - if client is not None: - if client.token_id is not None: - builder.internal_metadata.token_id = client.token_id - if client.device_id is not None: - builder.internal_metadata.device_id = client.device_id + if token_id is not None: + builder.internal_metadata.token_id = token_id if txn_id is not None: builder.internal_metadata.txn_id = txn_id diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 2ce754b028..504b63eab4 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 6758a888b3..4dcda57c1b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet): try: # try to auth as a user - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) try: user_id = user.to_string() yield dir_handler.create_association( @@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet): # fallback to default user behaviour if they aren't an AS pass - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 77b7c25a03..582148b659 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 4a259bba64..4ea4da653c 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) with_feedback = "feedback" in request.args as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 78d4f2b128..a770efd841 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 1e77eb49cf..fdde88a60d 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index c83287c028..3aabc93b8b 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -65,7 +65,7 @@ class PusherRestServlet(ClientV1RestServlet): try: yield pusher_pool.add_pusher( user_name=user.to_string(), - access_token=client.token_id, + access_token=token_id, profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index b4a70cba99..c9c27dd5a0 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -159,7 +159,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): msg_handler = self.handlers.message_handler yield msg_handler.create_and_send_event( - event_dict, client=client, txn_id=txn_id, + event_dict, token_id=token_id, txn_id=txn_id, ) defer.returnValue((200, {})) @@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -186,7 +186,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "room_id": room_id, "sender": user.to_string(), }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -250,7 +250,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): "sender": user.to_string(), "state_key": user.to_string(), }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.room_member_handler members = yield handler.get_room_members_as_pagination_chunk( room_id=room_id, @@ -317,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -341,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -357,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -402,7 +402,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -427,7 +427,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): "sender": user.to_string(), "state_key": state_key, }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -457,7 +457,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): - user, client = yield self.auth.get_user_by_req(request) + user, token_id = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -469,7 +469,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "sender": user.to_string(), "redacts": event_id, }, - client=client, + token_id=token_id, txn_id=txn_id, ) @@ -497,7 +497,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 11d08fbced..4ae2d81b70 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 522a312c9e..b5edffdb60 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet): if LoginType.PASSWORD in result: # if using password, they should also be logged in - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if auth_user.to_string() != result[LoginType.PASSWORD]: raise LoginError(400, "", Codes.UNKNOWN) user_id = auth_user.to_string() @@ -119,7 +119,7 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) threePidCreds = body['threePidCreds'] - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 703250cea8..f8f91b63f5 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot get filters for other users") @@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot create filters for other users") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 718928eedd..ec1145454f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -63,7 +63,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. @@ -108,7 +108,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() result = yield self.store.count_e2e_one_time_keys(user_id, device_id) @@ -180,7 +180,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - auth_user, client_info = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 40406e2ede..52e99f54d5 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -39,7 +39,7 @@ class ReceiptRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) yield self.receipts_handler.received_client_receipt( room_id, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index f2fd0b9f32..83a257b969 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -87,7 +87,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, client = yield self.auth.get_user_by_req(request) + user, _ = yield self.auth.get_user_by_req(request) timeout = parse_integer(request, "timeout", default=0) limit = parse_integer(request, "limit", required=True) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index e77a20fb2e..c28dc86cd7 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index cdd1d44e07..439d5a30a8 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): - auth_user, client = yield self.auth.get_user_by_req(request) + auth_user, _ = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 53673b3bf5..77cb1dbd81 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -94,9 +94,9 @@ class DataStore(RoomMemberStore, RoomStore, ) @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, device_id, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) - key = (user.to_string(), access_token, device_id, ip) + key = (user.to_string(), access_token, ip) try: last_seen = self.client_ip_last_seen.get(key) @@ -120,7 +120,6 @@ class DataStore(RoomMemberStore, RoomStore, "user_agent": user_agent, }, values={ - "device_id": device_id, "last_seen": now, }, desc="insert_client_ip", @@ -132,7 +131,7 @@ class DataStore(RoomMemberStore, RoomStore, table="user_ips", keyvalues={"user_id": user.to_string()}, retcols=[ - "device_id", "access_token", "ip", "user_agent", "last_seen" + "access_token", "ip", "user_agent", "last_seen" ], desc="get_user_ip_and_agents", ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index f632306688..240d14c4d0 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -163,7 +163,7 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id), device_id and whether they are + dict: Including the name (user_id) and whether they are an admin. Raises: StoreError if no user was found. @@ -228,8 +228,7 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin," - " access_tokens.device_id, access_tokens.id as token_id" + "SELECT users.name, users.admin, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/types.py b/synapse/types.py index e190374cbd..9cffc33d27 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -209,7 +209,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) - - -# token_id is the primary key ID of the access token, not the access token itself. -ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 3343c635cc..777eb0395e 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -40,7 +40,6 @@ class AuthTestCase(unittest.TestCase): self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, - "device_id": "nothing", "token_id": "ditto", "admin": False } @@ -49,7 +48,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -66,7 +65,6 @@ class AuthTestCase(unittest.TestCase): self.store.get_app_service_by_token = Mock(return_value=None) user_info = { "name": self.test_user, - "device_id": "nothing", "token_id": "ditto", "admin": False } @@ -86,7 +84,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_appservice_bad_token(self): @@ -121,7 +119,7 @@ class AuthTestCase(unittest.TestCase): request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, info) = yield self.auth.get_user_by_req(request) + (user, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), masquerading_user_id) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0b78a82a66..4039a86d85 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -74,7 +74,6 @@ class PresenceStateTestCase(unittest.TestCase): return { "user": UserID.from_string(myid), "admin": False, - "device_id": None, "token_id": 1, } @@ -163,7 +162,6 @@ class PresenceListTestCase(unittest.TestCase): return { "user": UserID.from_string(myid), "admin": False, - "device_id": None, "token_id": 1, } diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2e55cc08a1..dd1e67e0f9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -58,7 +58,6 @@ class RoomPermissionsTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -445,7 +444,6 @@ class RoomsMemberListTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -525,7 +523,6 @@ class RoomsCreateTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -618,7 +615,6 @@ class RoomTopicTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } @@ -725,7 +721,6 @@ class RoomMemberStateTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -852,7 +847,6 @@ class RoomMessagesTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -949,7 +943,6 @@ class RoomInitialSyncTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index dc8bbaaf0e..0f70ce81dc 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -65,7 +65,6 @@ class RoomTypingTestCase(RestTestCase): return { "user": UserID.from_string(self.auth_user_id), "admin": False, - "device_id": None, "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 15568b36cd..badb59f080 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -47,7 +47,6 @@ class V2AlphaRestTestCase(unittest.TestCase): return { "user": UserID.from_string(self.USER_ID), "admin": False, - "device_id": None, "token_id": 1, } hs.get_auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index a4f929796a..54fe10d58f 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -54,7 +54,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result @@ -72,7 +71,6 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { "admin": 0, - "device_id": None, "name": self.user_id, }, result diff --git a/tests/utils.py b/tests/utils.py index d0fba2252d..ff560ef342 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -282,7 +282,6 @@ class MemoryDataStore(object): return { "name": self.tokens_to_users[token], "admin": 0, - "device_id": None, } except: raise StoreError(400, "User does not exist.") @@ -380,7 +379,7 @@ class MemoryDataStore(object): def get_ops_levels(self, room_id): return defer.succeed((5, 5, 5)) - def insert_client_ip(self, user, device_id, access_token, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent): return defer.succeed(None) -- cgit 1.5.1 From a9d8bd95e722e24c7ddd6b14a3714c1b2f737d4d Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 25 Aug 2015 16:29:39 +0100 Subject: Stop looking up "admin", which we never read --- synapse/api/auth.py | 4 +--- synapse/storage/registration.py | 5 ++--- tests/api/test_auth.py | 2 -- tests/rest/client/v1/test_presence.py | 2 -- tests/rest/client/v1/test_rooms.py | 7 ------- tests/rest/client/v1/test_typing.py | 1 - tests/rest/client/v2_alpha/__init__.py | 1 - tests/storage/test_registration.py | 6 ++---- tests/utils.py | 1 - 9 files changed, 5 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b41e34e658..65ee1452ce 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -392,8 +392,7 @@ class Auth(object): Args: token (str): The access token to get the user by. Returns: - dict : dict that includes the user and whether the - user is a server admin. + dict : dict that includes the user and the ID of their access token. Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -404,7 +403,6 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) user_info = { - "admin": bool(ret.get("admin", False)), "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 240d14c4d0..a2d0f7c4b1 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -163,8 +163,7 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and whether they are - an admin. + dict: Including the name (user_id) and the ID of their access token. Raises: StoreError if no user was found. """ @@ -228,7 +227,7 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin, access_tokens.id as token_id" + "SELECT users.name, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 777eb0395e..22fc804331 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -41,7 +41,6 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", - "admin": False } self.store.get_user_by_access_token = Mock(return_value=user_info) @@ -66,7 +65,6 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", - "admin": False } self.store.get_user_by_access_token = Mock(return_value=user_info) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 4039a86d85..91547bdd06 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -73,7 +73,6 @@ class PresenceStateTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), - "admin": False, "token_id": 1, } @@ -161,7 +160,6 @@ class PresenceListTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), - "admin": False, "token_id": 1, } diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index dd1e67e0f9..34ab47d02e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -57,7 +57,6 @@ class RoomPermissionsTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -443,7 +442,6 @@ class RoomsMemberListTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -522,7 +520,6 @@ class RoomsCreateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -614,7 +611,6 @@ class RoomTopicTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } @@ -720,7 +716,6 @@ class RoomMemberStateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -846,7 +841,6 @@ class RoomMessagesTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token @@ -942,7 +936,6 @@ class RoomInitialSyncTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 0f70ce81dc..1c4519406d 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -64,7 +64,6 @@ class RoomTypingTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), - "admin": False, "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index badb59f080..ef972a53aa 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -46,7 +46,6 @@ class V2AlphaRestTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.USER_ID), - "admin": False, "token_id": 1, } hs.get_auth().get_user_by_access_token = _get_user_by_access_token diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 54fe10d58f..0cce6c37df 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,8 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { - "admin": 0, - "name": self.user_id, + "name": self.user_id, }, result ) @@ -70,8 +69,7 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertDictContainsSubset( { - "admin": 0, - "name": self.user_id, + "name": self.user_id, }, result ) diff --git a/tests/utils.py b/tests/utils.py index ff560ef342..3766a994f2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -281,7 +281,6 @@ class MemoryDataStore(object): try: return { "name": self.tokens_to_users[token], - "admin": 0, } except: raise StoreError(400, "User does not exist.") -- cgit 1.5.1 From 6a4b650d8ad3e6c095020cac3861e430d643d53d Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 26 Aug 2015 13:22:23 +0100 Subject: Attempt to validate macaroons A couple of weird caveats: * If we can't validate your macaroon, we fall back to checking that your access token is in the DB, and ignoring the failure * Even if we can validate your macaroon, we still have to hit the DB to get the access token ID, which we pretend is a device ID all over the codebase. This mostly adds the interesting code, and points out the two pieces we need to delete (and necessary conditions) in order to fix the above caveats. --- synapse/api/auth.py | 104 +++++++++++++++++++++--- tests/api/test_auth.py | 142 ++++++++++++++++++++++++++++++++- tests/rest/client/v1/test_presence.py | 8 +- tests/rest/client/v1/test_rooms.py | 28 +++---- tests/rest/client/v1/test_typing.py | 4 +- tests/rest/client/v1/utils.py | 3 - tests/rest/client/v2_alpha/__init__.py | 4 +- 7 files changed, 257 insertions(+), 36 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 65ee1452ce..f8ea1e2c69 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -23,6 +23,7 @@ from synapse.util.logutils import log_function from synapse.types import UserID import logging +import pymacaroons logger = logging.getLogger(__name__) @@ -40,6 +41,12 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 + self._KNOWN_CAVEAT_PREFIXES = set([ + "gen = ", + "type = ", + "time < ", + "user_id = ", + ]) def check(self, event, auth_events): """ Checks if this event is correctly authed. @@ -359,8 +366,8 @@ class Auth(object): except KeyError: pass # normal users won't have the user_id query parameter set. - user_info = yield self.get_user_by_access_token(access_token) - user = user_info["user"] + user_info = yield self._get_user_by_access_token(access_token) + user_id = user_info["user_id"] token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) @@ -368,17 +375,17 @@ class Auth(object): "User-Agent", default=[""] )[0] - if user and access_token and ip_addr: + if user_id and access_token and ip_addr: self.store.insert_client_ip( - user=user, + user=user_id, access_token=access_token, ip=ip_addr, user_agent=user_agent ) - request.authenticated_entity = user.to_string() + request.authenticated_entity = user_id.to_string() - defer.returnValue((user, token_id,)) + defer.returnValue((user_id, token_id,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -386,7 +393,7 @@ class Auth(object): ) @defer.inlineCallbacks - def get_user_by_access_token(self, token): + def _get_user_by_access_token(self, token): """ Get a registered user's ID. Args: @@ -396,6 +403,86 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ + try: + ret = yield self._get_user_from_macaroon(token) + except AuthError: + # TODO(daniel): Remove this fallback when all existing access tokens + # have been re-issued as macaroons. + ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) + + @defer.inlineCallbacks + def _get_user_from_macaroon(self, macaroon_str): + try: + macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) + self._validate_macaroon(macaroon) + + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + user_id = UserID.from_string(caveat.caveat_id[len(user_prefix):]) + # This codepath exists so that we can actually return a + # token ID, because we use token IDs in place of device + # identifiers throughout the codebase. + # TODO(daniel): Remove this fallback when device IDs are + # properly implemented. + ret = yield self._look_up_user_by_access_token(macaroon_str) + if ret["user_id"] != user_id: + logger.error( + "Macaroon user (%s) != DB user (%s)", + user_id, + ret["user_id"] + ) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "User mismatch in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + defer.returnValue(ret) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN + ) + + def _validate_macaroon(self, macaroon): + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = access") + v.satisfy_general(lambda c: c.startswith("user_id = ")) + v.satisfy_general(self._verify_expiry) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + v = pymacaroons.Verifier() + v.satisfy_general(self._verify_recognizes_caveats) + v.verify(macaroon, self.hs.config.macaroon_secret_key) + + def _verify_expiry(self, caveat): + prefix = "time < " + if not caveat.startswith(prefix): + return False + # TODO(daniel): Enable expiry check when clients actually know how to + # refresh tokens. (And remember to enable the tests) + return True + expiry = int(caveat[len(prefix):]) + now = self.hs.get_clock().time_msec() + return now < expiry + + def _verify_recognizes_caveats(self, caveat): + first_space = caveat.find(" ") + if first_space < 0: + return False + second_space = caveat.find(" ", first_space + 1) + if second_space < 0: + return False + return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES + + @defer.inlineCallbacks + def _look_up_user_by_access_token(self, token): ret = yield self.store.get_user_by_access_token(token) if not ret: raise AuthError( @@ -403,10 +490,9 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) user_info = { - "user": UserID.from_string(ret.get("name")), + "user_id": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } - defer.returnValue(user_info) @defer.inlineCallbacks diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 22fc804331..1ba85d6f83 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -14,22 +14,27 @@ # limitations under the License. from tests import unittest from twisted.internet import defer +from twisted.trial.unittest import FailTest from mock import Mock from synapse.api.auth import Auth from synapse.api.errors import AuthError +from synapse.types import UserID +from tests.utils import setup_test_homeserver + +import pymacaroons class AuthTestCase(unittest.TestCase): + @defer.inlineCallbacks def setUp(self): self.state_handler = Mock() self.store = Mock() - self.hs = Mock() + self.hs = yield setup_test_homeserver(handlers=None) self.hs.get_datastore = Mock(return_value=self.store) - self.hs.get_state_handler = Mock(return_value=self.state_handler) self.auth = Auth(self.hs) self.test_user = "@foo:bar" @@ -133,3 +138,136 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = Mock(return_value=[""]) d = self.auth.get_user_by_req(request) self.failureResultOf(d, AuthError) + + @defer.inlineCallbacks + def test_get_user_from_macaroon(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize()) + user_id = user_info["user_id"] + self.assertEqual(UserID.from_string(user), user_id) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_user_db_mismatch(self): + self.store.get_user_by_access_token = Mock( + return_value={"name": "@percy:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("User mismatch", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_missing_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("No user caveat", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_wrong_key(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key + "wrong") + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_unknown_caveat(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("cunning > fox") + + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_expired(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.todo = (FailTest, "Token expiry isn't currently enabled",) + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("time < 1") # ms + + self.hs.clock.now = 5000 # seconds + with self.assertRaises(AuthError) as cm: + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + self.assertEqual(401, cm.exception.code) + self.assertIn("Invalid macaroon", cm.exception.msg) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 91547bdd06..d8d1416f59 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -72,11 +72,11 @@ class PresenceStateTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(myid), + "user_id": UserID.from_string(myid), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(myid), + "user_id": UserID.from_string(myid), "token_id": 1, } @@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token presence.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 34ab47d02e..be1d52f720 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -56,10 +56,10 @@ class RoomPermissionsTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -441,10 +441,10 @@ class RoomsMemberListTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -519,10 +519,10 @@ class RoomsCreateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -610,11 +610,11 @@ class RoomTopicTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -715,10 +715,10 @@ class RoomMemberStateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -840,10 +840,10 @@ class RoomMessagesTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -935,10 +935,10 @@ class RoomInitialSyncTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 1c4519406d..da6fc975f7 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -63,11 +63,11 @@ class RoomTypingTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.auth_user_id), + "user_id": UserID.from_string(self.auth_user_id), "token_id": 1, } - hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c472d53043..85096a0326 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase): self.mock_resource = None self.auth_user_id = None - def mock_get_user_by_access_token(self, token=None): - return self.auth_user_id - @defer.inlineCallbacks def create_room_as(self, room_creator, is_public=True, tok=None): temp_id = self.auth_user_id diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index ef972a53aa..7d0f77a3ee 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -45,10 +45,10 @@ class V2AlphaRestTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { - "user": UserID.from_string(self.USER_ID), + "user_id": UserID.from_string(self.USER_ID), "token_id": 1, } - hs.get_auth().get_user_by_access_token = _get_user_by_access_token + hs.get_auth()._get_user_by_access_token = _get_user_by_access_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) -- cgit 1.5.1 From 81450fded8c4d2a0f4a914251cc2d11a366efdbd Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 26 Aug 2015 13:56:01 +0100 Subject: Turn TODO into thing which actually will fail --- tests/api/test_auth.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 1ba85d6f83..2e2d0c428a 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -14,7 +14,6 @@ # limitations under the License. from tests import unittest from twisted.internet import defer -from twisted.trial.unittest import FailTest from mock import Mock @@ -251,7 +250,6 @@ class AuthTestCase(unittest.TestCase): return_value={"name": "@baldrick:matrix.org"} ) - self.todo = (FailTest, "Token expiry isn't currently enabled",) self.store.get_user_by_access_token = Mock( return_value={"name": "@baldrick:matrix.org"} ) @@ -267,7 +265,12 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("time < 1") # ms self.hs.clock.now = 5000 # seconds - with self.assertRaises(AuthError) as cm: - yield self.auth._get_user_from_macaroon(macaroon.serialize()) - self.assertEqual(401, cm.exception.code) - self.assertIn("Invalid macaroon", cm.exception.msg) + + yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # TODO(daniel): Turn on the check that we validate expiration, when we + # validate expiration (and remove the above line, which will start + # throwing). + # with self.assertRaises(AuthError) as cm: + # yield self.auth._get_user_from_macaroon(macaroon.serialize()) + # self.assertEqual(401, cm.exception.code) + # self.assertIn("Invalid macaroon", cm.exception.msg) -- cgit 1.5.1 From 3063383547529a542b48f416d64fd98eaf6a2f60 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 26 Aug 2015 15:59:32 +0100 Subject: Swap out bcrypt for md5 in tests This reduces our ~8 second sequential test time down to ~7 seconds --- synapse/handlers/auth.py | 27 +++++++++++++++++++++++++-- synapse/handlers/register.py | 2 +- tests/utils.py | 13 +++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1ab19cd1a6..59f687e0f1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -324,7 +324,7 @@ class AuthHandler(BaseHandler): def _check_password(self, user_id, password, stored_hash): """Checks that user_id has passed password, raises LoginError if not.""" - if not bcrypt.checkpw(password, stored_hash): + if not self.validate_hash(password, stored_hash): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -369,7 +369,7 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def set_password(self, user_id, newpassword): - password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) + password_hash = self.hash(newpassword) yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_delete_access_tokens(user_id) @@ -391,3 +391,26 @@ class AuthHandler(BaseHandler): def _remove_session(self, session): logger.debug("Removing session %s", session) del self.sessions[session["id"]] + + def hash(self, password): + """Computes a secure hash of password. + + Args: + password (str): Password to hash. + + Returns: + Hashed password (str). + """ + return bcrypt.hashpw(password, bcrypt.gensalt()) + + def validate_hash(self, password, stored_hash): + """Validates that self.hash(password) == stored_hash. + + Args: + password (str): Password to hash. + stored_hash (str): Expected hash value. + + Returns: + Whether self.hash(password) == stored_hash (bool). + """ + return bcrypt.checkpw(password, stored_hash) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 56d125f753..855bb58522 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -82,7 +82,7 @@ class RegistrationHandler(BaseHandler): yield run_on_reactor() password_hash = None if password: - password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) + password_hash = self.auth_handler().hash(password) if localpart: yield self.check_username(localpart) diff --git a/tests/utils.py b/tests/utils.py index 3766a994f2..dd19a16fc7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -27,6 +27,7 @@ from twisted.enterprise.adbapi import ConnectionPool from collections import namedtuple from mock import patch, Mock +import hashlib import urllib import urlparse @@ -67,6 +68,18 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): **kargs ) + # bcrypt is far too slow to be doing in unit tests + def swap_out_hash_for_testing(old_build_handlers): + def build_handlers(): + handlers = old_build_handlers() + auth_handler = handlers.auth_handler + auth_handler.hash = lambda p: hashlib.md5(p).hexdigest() + auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h + return handlers + return build_handlers + + hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) + defer.returnValue(hs) -- cgit 1.5.1 From e255c2c32ff85db03abbf2dac184b2949f481cfb Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 1 Sep 2015 12:41:16 +0100 Subject: s/user_id/user/g for consistency --- synapse/api/auth.py | 20 ++++++++++---------- tests/api/test_auth.py | 8 ++++---- tests/rest/client/v1/test_presence.py | 4 ++-- tests/rest/client/v1/test_rooms.py | 14 +++++++------- tests/rest/client/v1/test_typing.py | 2 +- tests/rest/client/v2_alpha/__init__.py | 2 +- 6 files changed, 25 insertions(+), 25 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f8ea1e2c69..0a77a76cb8 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -367,7 +367,7 @@ class Auth(object): pass # normal users won't have the user_id query parameter set. user_info = yield self._get_user_by_access_token(access_token) - user_id = user_info["user_id"] + user = user_info["user"] token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) @@ -375,17 +375,17 @@ class Auth(object): "User-Agent", default=[""] )[0] - if user_id and access_token and ip_addr: + if user and access_token and ip_addr: self.store.insert_client_ip( - user=user_id, + user=user, access_token=access_token, ip=ip_addr, user_agent=user_agent ) - request.authenticated_entity = user_id.to_string() + request.authenticated_entity = user.to_string() - defer.returnValue((user_id, token_id,)) + defer.returnValue((user, token_id,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -420,18 +420,18 @@ class Auth(object): user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): - user_id = UserID.from_string(caveat.caveat_id[len(user_prefix):]) + user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) # This codepath exists so that we can actually return a # token ID, because we use token IDs in place of device # identifiers throughout the codebase. # TODO(daniel): Remove this fallback when device IDs are # properly implemented. ret = yield self._look_up_user_by_access_token(macaroon_str) - if ret["user_id"] != user_id: + if ret["user"] != user: logger.error( "Macaroon user (%s) != DB user (%s)", - user_id, - ret["user_id"] + user, + ret["user"] ) raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, @@ -490,7 +490,7 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) user_info = { - "user_id": UserID.from_string(ret.get("name")), + "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), } defer.returnValue(user_info) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 2e2d0c428a..c96273480d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -146,17 +146,17 @@ class AuthTestCase(unittest.TestCase): return_value={"name": "@baldrick:matrix.org"} ) - user = "@baldrick:matrix.org" + user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", key=self.hs.config.macaroon_secret_key) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") - macaroon.add_first_party_caveat("user_id = %s" % (user,)) + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize()) - user_id = user_info["user_id"] - self.assertEqual(UserID.from_string(user), user_id) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) @defer.inlineCallbacks def test_get_user_from_macaroon_user_db_mismatch(self): diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index d8d1416f59..2ee3da0b34 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -72,7 +72,7 @@ class PresenceStateTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(myid), + "user": UserID.from_string(myid), "token_id": 1, } @@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(myid), + "user": UserID.from_string(myid), "token_id": 1, } diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index be1d52f720..9fb2bfb315 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -56,7 +56,7 @@ class RoomPermissionsTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -441,7 +441,7 @@ class RoomsMemberListTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -519,7 +519,7 @@ class RoomsCreateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -610,7 +610,7 @@ class RoomTopicTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } @@ -715,7 +715,7 @@ class RoomMemberStateTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -840,7 +840,7 @@ class RoomMessagesTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -935,7 +935,7 @@ class RoomInitialSyncTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index da6fc975f7..6395ce79db 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -63,7 +63,7 @@ class RoomTypingTestCase(RestTestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.auth_user_id), + "user": UserID.from_string(self.auth_user_id), "token_id": 1, } diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 7d0f77a3ee..f45570a1c0 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -45,7 +45,7 @@ class V2AlphaRestTestCase(unittest.TestCase): def _get_user_by_access_token(token=None): return { - "user_id": UserID.from_string(self.USER_ID), + "user": UserID.from_string(self.USER_ID), "token_id": 1, } hs.get_auth()._get_user_by_access_token = _get_user_by_access_token -- cgit 1.5.1 From 00149c063b8f81548bd3eefd3e497acc03512d35 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Sep 2015 15:42:03 +0100 Subject: Fix tests --- synapse/api/auth.py | 2 +- tests/test_state.py | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 410f4c11e7..df7fb6aab7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -69,7 +69,7 @@ class Auth(object): if not creation_event: raise SynapseError( - 400, + 403, "Room %r does not exist" % (event.room_id,) ) diff --git a/tests/test_state.py b/tests/test_state.py index 5845358754..55f37c521f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase): nodes={ "START": DictObj( type=EventTypes.Create, - state_key="creator", - content={"membership": "@user_id:example.com"}, + state_key="", + content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( @@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase): nodes={ "START": DictObj( type=EventTypes.Create, - state_key="creator", - content={"membership": "@user_id:example.com"}, + state_key="", + content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( @@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase): def test_resolve_message_conflict(self): event = create_event(type="test_message", name="event") + creation = create_event( + type=EventTypes.Create, state_key="" + ) + old_state_1 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), @@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 5) + self.assertEqual(len(context.current_state), 6) self.assertIsNone(context.state_group) @@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase): def test_resolve_state_conflict(self): event = create_event(type="test4", state_key="", name="event") + creation = create_event( + type=EventTypes.Create, state_key="" + ) + old_state_1 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ + creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), @@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 5) + self.assertEqual(len(context.current_state), 6) self.assertIsNone(context.state_group) @@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase): } ) + creation = create_event( + type=EventTypes.Create, state_key="", + content={"creator": "@foo:bar"} + ) + old_state_1 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=2), ] context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_2[1], context.current_state[("test1", "1")]) + self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) # Reverse the depth to make sure we are actually using the depths # during state resolution. old_state_1 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ + creation, member_event, create_event(type="test1", state_key="1", depth=1), ] context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_1[1], context.current_state[("test1", "1")]) + self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" -- cgit 1.5.1 From b345853918b9300bdde19010d29bf66973497de7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Sep 2015 15:57:35 +0100 Subject: Check against sender rather than event_id --- synapse/api/auth.py | 6 +++--- tests/test_state.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index f7cf17e433..75b7c467b5 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError from synapse.util.logutils import log_function -from synapse.types import EventID, RoomID, UserID +from synapse.types import RoomID, UserID import logging @@ -66,10 +66,10 @@ class Auth(object): return True creating_domain = RoomID.from_string(event.room_id).domain - originating_domain = EventID.from_string(event.event_id).domain + originating_domain = UserID.from_string(event.sender).domain if creating_domain != originating_domain: if not self.can_federate(event, auth_events): - raise SynapseError( + raise AuthError( 403, "This room has been marked as unfederatable." ) diff --git a/tests/test_state.py b/tests/test_state.py index 5845358754..04c4439183 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -35,7 +35,7 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, if not event_id: _next_event_id += 1 - event_id = str(_next_event_id) + event_id = "$%s:test" % (_next_event_id,) if not name: if state_key is not None: -- cgit 1.5.1 From bc8b25eb56bf4fcec3546c2ea28741189a519da5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 9 Sep 2015 15:42:16 +0100 Subject: Allow users that have left the room to view the member list from the point they left --- synapse/handlers/room.py | 36 ------------------------------------ synapse/rest/client/v1/room.py | 18 +++++++++++++----- tests/rest/client/v1/test_rooms.py | 4 ++-- 3 files changed, 15 insertions(+), 43 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index c5d1001b50..0ff816d53e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,7 +25,6 @@ from synapse.api.constants import ( from synapse.api.errors import StoreError, SynapseError from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor -from synapse.events.utils import serialize_event from collections import OrderedDict import logging @@ -342,41 +341,6 @@ class RoomMemberHandler(BaseHandler): if remotedomains is not None: remotedomains.add(member.domain) - @defer.inlineCallbacks - def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None, - limit=0, start_tok=None, - end_tok=None): - """Retrieve a list of room members in the room. - - Args: - room_id (str): The room to get the member list for. - user_id (str): The ID of the user making the request. - limit (int): The max number of members to return. - start_tok (str): Optional. The start token if known. - end_tok (str): Optional. The end token if known. - Returns: - dict: A Pagination streamable dict. - Raises: - SynapseError if something goes wrong. - """ - yield self.auth.check_joined_room(room_id, user_id) - - member_list = yield self.store.get_room_members(room_id=room_id) - time_now = self.clock.time_msec() - event_list = [ - serialize_event(entry, time_now) - for entry in member_list - ] - chunk_data = { - "start": "START", # FIXME (erikj): START is no longer valid - "end": "END", - "chunk": event_list - } - # TODO honor Pagination stream params - # TODO snapshot this list to return on subsequent requests when - # paginating - defer.returnValue(chunk_data) - @defer.inlineCallbacks def change_membership(self, event, context, do_auth=True): """ Change the membership status of a user in a room. diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index c9c27dd5a0..f4558b95a7 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -290,12 +290,18 @@ class RoomMemberListRestServlet(ClientV1RestServlet): def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) user, _ = yield self.auth.get_user_by_req(request) - handler = self.handlers.room_member_handler - members = yield handler.get_room_members_as_pagination_chunk( + handler = self.handlers.message_handler + events = yield handler.get_state_events( room_id=room_id, - user_id=user.to_string()) + user_id=user.to_string(), + ) + + chunk = [] - for event in members["chunk"]: + for event in events: + if event["type"] != EventTypes.Member: + continue + chunk.append(event) # FIXME: should probably be state_key here, not user_id target_user = UserID.from_string(event["user_id"]) # Presence is an optional cache; don't fail if we can't fetch it @@ -308,7 +314,9 @@ class RoomMemberListRestServlet(ClientV1RestServlet): except: pass - defer.returnValue((200, members)) + defer.returnValue((200, { + "chunk": chunk + })) # TODO: Needs unit testing diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 34ab47d02e..d50cfe4298 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -492,9 +492,9 @@ class RoomsMemberListTestCase(RestTestCase): self.assertEquals(200, code, msg=str(response)) yield self.leave(room=room_id, user=self.user_id) - # can no longer see list, you've left. + # can see old list once left (code, response) = yield self.mock_resource.trigger_get(room_path) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(200, code, msg=str(response)) class RoomsCreateTestCase(RestTestCase): -- cgit 1.5.1 From e2054ce21a04f3d741293f50b283c01bbe2b0591 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 10 Sep 2015 15:06:47 +0100 Subject: Allow users to GET individual state events for rooms that they have left --- synapse/handlers/message.py | 20 +++++++++++++------- tests/rest/client/v1/test_rooms.py | 10 +++++----- 2 files changed, 18 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index db89491b46..5d18aaacf0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import RoomError, SynapseError +from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -277,13 +277,19 @@ class MessageHandler(BaseHandler): Raises: SynapseError if something went wrong. """ - have_joined = yield self.auth.check_joined_room(room_id, user_id) - if not have_joined: - raise RoomError(403, "User not in room.") + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + + if member_event.membership == Membership.JOIN: + data = yield self.state_handler.get_current_state( + room_id, event_type, state_key + ) + elif member_event.membership == Membership.LEAVE: + key = (event_type, state_key) + room_state = yield self.store.get_state_for_events( + room_id, [member_event.event_id], [key] + ) + data = room_state[member_event.event_id].get(key) - data = yield self.state_handler.get_current_state( - room_id, event_type, state_key - ) defer.returnValue(data) @defer.inlineCallbacks diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index d50cfe4298..ed0ac8d5c8 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -239,7 +239,7 @@ class RoomPermissionsTestCase(RestTestCase): "PUT", topic_path, topic_content) self.assertEquals(403, code, msg=str(response)) (code, response) = yield self.mock_resource.trigger_get(topic_path) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(200, code, msg=str(response)) # get topic in PUBLIC room, not joined, expect 403 (code, response) = yield self.mock_resource.trigger_get( @@ -301,11 +301,11 @@ class RoomPermissionsTestCase(RestTestCase): room=room, expect_code=200) # get membership of self, get membership of other, private room + left - # expect all 403s + # expect all 200s yield self.leave(room=room, user=self.user_id) yield self._test_get_membership( members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + room=room, expect_code=200) @defer.inlineCallbacks def test_membership_public_room_perms(self): @@ -326,11 +326,11 @@ class RoomPermissionsTestCase(RestTestCase): room=room, expect_code=200) # get membership of self, get membership of other, public room + left - # expect 403. + # expect 200. yield self.leave(room=room, user=self.user_id) yield self._test_get_membership( members=[self.user_id, self.rmcreator_id], - room=room, expect_code=403) + room=room, expect_code=200) @defer.inlineCallbacks def test_invited_permissions(self): -- cgit 1.5.1 From 7213588083dd9a721b0cd623fe22b308f25f19a5 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 22 Sep 2015 12:57:40 +0100 Subject: Implement configurable stats reporting SYN-287 This requires that HS owners either opt in or out of stats reporting. When --generate-config is passed, --report-stats must be specified If an already-generated config is used, and doesn't have the report_stats key, it is requested to be set. --- synapse/app/homeserver.py | 35 ++++++- synapse/app/synctl.py | 12 ++- synapse/config/_base.py | 45 +++++++- synapse/config/appservice.py | 2 +- synapse/config/captcha.py | 2 +- synapse/config/database.py | 2 +- synapse/config/key.py | 2 +- synapse/config/logger.py | 2 +- synapse/config/metrics.py | 8 +- synapse/config/ratelimiting.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/server.py | 2 +- synapse/config/tls.py | 2 +- synapse/config/voip.py | 2 +- synapse/storage/__init__.py | 20 +++- synapse/storage/events.py | 58 ++++++++++- synapse/storage/registration.py | 12 +++ .../storage/schema/delta/24/stats_reporting.sql | 22 ++++ tests/storage/event_injector.py | 81 ++++++++++++++ tests/storage/test_events.py | 116 +++++++++++++++++++++ tests/storage/test_room.py | 2 +- tests/storage/test_stream.py | 68 +++--------- 24 files changed, 425 insertions(+), 78 deletions(-) create mode 100644 synapse/storage/schema/delta/24/stats_reporting.sql create mode 100644 tests/storage/event_injector.py create mode 100644 tests/storage/test_events.py (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 15c0a4a003..b4429bd4f3 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -42,7 +42,7 @@ from synapse.storage import ( from synapse.server import HomeServer -from twisted.internet import reactor +from twisted.internet import reactor, task, defer from twisted.application import service from twisted.enterprise import adbapi from twisted.web.resource import Resource, EncodingResourceWrapper @@ -677,6 +677,39 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) + start_time = hs.get_clock().time() + + @defer.inlineCallbacks + def phone_stats_home(): + now = int(hs.get_clock().time()) + uptime = int(now - start_time) + if uptime < 0: + uptime = 0 + + stats = {} + stats["homeserver"] = hs.config.server_name + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + stats["total_users"] = yield hs.get_datastore().count_all_users() + + all_rooms = yield hs.get_datastore().get_rooms(False) + stats["total_room_count"] = len(all_rooms) + + stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() + daily_messages = yield hs.get_datastore().count_daily_messages() + if daily_messages is not None: + stats["daily_messages"] = daily_messages + + logger.info("Reporting stats to matrix.org: %s" % (stats,)) + hs.get_simple_http_client().put_json( + "https://matrix.org/report-usage-stats/push", + stats + ) + + if hs.config.report_stats: + phone_home_task = task.LoopingCall(phone_stats_home) + phone_home_task.start(60 * 60 * 24, now=False) + def in_thread(): with LoggingContext("run"): change_resource_limit(hs.config.soft_file_limit) diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 1f7d543c31..6bcc437591 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -25,6 +25,7 @@ SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] CONFIGFILE = "homeserver.yaml" GREEN = "\x1b[1;32m" +RED = "\x1b[1;31m" NORMAL = "\x1b[m" if not os.path.exists(CONFIGFILE): @@ -45,8 +46,15 @@ def start(): print "Starting ...", args = SYNAPSE args.extend(["--daemonize", "-c", CONFIGFILE]) - subprocess.check_call(args) - print GREEN + "started" + NORMAL + try: + subprocess.check_call(args) + print GREEN + "started" + NORMAL + except subprocess.CalledProcessError as e: + print ( + RED + + "error starting (exit code: %d); see above for logs" % e.returncode + + NORMAL + ) def stop(): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 8a75c48733..b9983f72a2 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -26,6 +26,16 @@ class ConfigError(Exception): class Config(object): + stats_reporting_begging_spiel = ( + "We would really appreciate it if you could help our project out by " + "reporting anonymized usage statistics from your homeserver. Only very " + "basic aggregate data (e.g. number of users) will be reported, but it " + "helps us to track the growth of the Matrix community, and helps us to " + "make Matrix a success, as well as to convince other networks that they " + "should peer with us.\n" + "Thank you." + ) + @staticmethod def parse_size(value): if isinstance(value, int) or isinstance(value, long): @@ -111,11 +121,14 @@ class Config(object): results.append(getattr(cls, name)(self, *args, **kargs)) return results - def generate_config(self, config_dir_path, server_name): + def generate_config(self, config_dir_path, server_name, report_stats=None): default_config = "# vim:ft=yaml\n" default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( - "default_config", config_dir_path, server_name + "default_config", + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=report_stats, )) config = yaml.load(default_config) @@ -139,6 +152,12 @@ class Config(object): action="store_true", help="Generate a config file for the server name" ) + config_parser.add_argument( + "--report-stats", + action="store", + help="Stuff", + choices=["yes", "no"] + ) config_parser.add_argument( "--generate-keys", action="store_true", @@ -189,6 +208,11 @@ class Config(object): config_files.append(config_path) if config_args.generate_config: + if config_args.report_stats is None: + config_parser.error( + "Please specify either --report-stats=yes or --report-stats=no\n\n" + + cls.stats_reporting_begging_spiel + ) if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" @@ -211,7 +235,9 @@ class Config(object): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( - config_dir_path, server_name + config_dir_path=config_dir_path, + server_name=server_name, + report_stats=(config_args.report_stats == "yes"), ) obj.invoke_all("generate_files", config) config_file.write(config_bytes) @@ -261,9 +287,20 @@ class Config(object): specified_config.update(yaml_config) server_name = specified_config["server_name"] - _, config = obj.generate_config(config_dir_path, server_name) + _, config = obj.generate_config( + config_dir_path=config_dir_path, + server_name=server_name + ) config.pop("log_config") config.update(specified_config) + if "report_stats" not in config: + sys.stderr.write( + "Please opt in or out of reporting anonymized homeserver usage " + "statistics, by setting the report_stats key in your config file " + " ( " + config_path + " ) " + + "to either True or False.\n\n" + + Config.stats_reporting_begging_spiel + "\n") + sys.exit(1) if generate_keys: obj.invoke_all("generate_files", config) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 38f41933b7..b8d301995e 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -20,7 +20,7 @@ class AppServiceConfig(Config): def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) - def default_config(cls, config_dir_path, server_name): + def default_config(cls, **kwargs): return """\ # A list of application service config file to use app_service_config_files: [] diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index 15a132b4e3..dd92fcd0dc 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -24,7 +24,7 @@ class CaptchaConfig(Config): self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Captcha ## diff --git a/synapse/config/database.py b/synapse/config/database.py index f0611e8884..baeda8f300 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -45,7 +45,7 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def default_config(self, config, config_dir_path): + def default_config(self, **kwargs): database_path = self.abspath("homeserver.db") return """\ # Database configuration diff --git a/synapse/config/key.py b/synapse/config/key.py index 23ac8a3fca..2c187065e5 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -40,7 +40,7 @@ class KeyConfig(Config): config["perspectives"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) return """\ ## Signing Keys ## diff --git a/synapse/config/logger.py b/synapse/config/logger.py index daca698d0c..bd0c17c861 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -70,7 +70,7 @@ class LoggingConfig(Config): self.log_config = self.abspath(config.get("log_config")) self.log_file = self.abspath(config.get("log_file")) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): log_file = self.abspath("homeserver.log") log_config = self.abspath( os.path.join(config_dir_path, server_name + ".log.config") diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index ae5a691527..825fec9a38 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -19,13 +19,15 @@ from ._base import Config class MetricsConfig(Config): def read_config(self, config): self.enable_metrics = config["enable_metrics"] + self.report_stats = config.get("report_stats", None) self.metrics_port = config.get("metrics_port") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") - def default_config(self, config_dir_path, server_name): - return """\ + def default_config(self, report_stats=None, **kwargs): + suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n" + return ("""\ ## Metrics ### # Enable collection and rendering of performance metrics enable_metrics: False - """ + """ + suffix) % locals() diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 76d9970e5b..611b598ec7 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -27,7 +27,7 @@ class RatelimitConfig(Config): self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_concurrent = config["federation_rc_concurrent"] - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Ratelimiting ## diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 62de4b399f..fa98eced34 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -34,7 +34,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") - def default_config(self, config_dir, server_name): + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) macaroon_secret_key = random_string_with_symbols(50) return """\ diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 64644b9a7a..2fcf872449 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config): config["thumbnail_sizes"] ) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): media_store = self.default_path("media_store") uploads_path = self.default_path("uploads") return """ diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 1532036876..4c6133cf22 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -41,7 +41,7 @@ class SAML2Config(Config): self.saml2_config_path = None self.saml2_idp_redirect_url = None - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): return """ # Enable SAML2 for registration and login. Uses pysaml2 # config_path: Path to the sp_conf.py configuration file diff --git a/synapse/config/server.py b/synapse/config/server.py index a03e55c223..4d12d49857 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -117,7 +117,7 @@ class ServerConfig(Config): self.content_addr = content_addr - def default_config(self, config_dir_path, server_name): + def default_config(self, server_name, **kwargs): if ":" in server_name: bind_port = int(server_name.split(":")[1]) unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e6023a718d..0ac2698293 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -50,7 +50,7 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - def default_config(self, config_dir_path, server_name): + def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" diff --git a/synapse/config/voip.py b/synapse/config/voip.py index a1707223d3..a093354ccd 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -22,7 +22,7 @@ class VoipConfig(Config): self.turn_shared_secret = config["turn_shared_secret"] self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) - def default_config(self, config_dir_path, server_name): + def default_config(self, **kwargs): return """\ ## Turn ## diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 77cb1dbd81..b64c90d631 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 23 +SCHEMA_VERSION = 24 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -126,6 +126,24 @@ class DataStore(RoomMemberStore, RoomStore, lock=False, ) + @defer.inlineCallbacks + def count_daily_users(self): + def _count_users(txn): + txn.execute( + "SELECT COUNT(DISTINCT user_id) AS users" + " FROM user_ips" + " WHERE last_seen > ?", + # This is close enough to a day for our purposes. + (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),) + ) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) + def get_user_ip_and_agents(self, user): return self._simple_select_list( table="user_ips", diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0a477e3122..2b51db9940 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from _base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer, reactor @@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json from contextlib import contextmanager import logging +import math import ujson as json logger = logging.getLogger(__name__) @@ -905,3 +905,59 @@ class EventsStore(SQLBaseStore): txn.execute(sql, (event.event_id,)) result = txn.fetchone() return result[0] if result else None + + @defer.inlineCallbacks + def count_daily_messages(self): + def _count_messages(txn): + now = self.hs.get_clock().time() + + txn.execute( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + last_reported = self.cursor_to_dict(txn) + + txn.execute( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + now_reporting = self.cursor_to_dict(txn) + if not now_reporting: + return None + now_reporting = now_reporting[0]["stream_ordering"] + + txn.execute("DELETE FROM stats_reporting") + txn.execute( + "INSERT INTO stats_reporting" + " (reported_stream_token, reported_time)" + " VALUES (?, ?)", + (now_reporting, now,) + ) + + if not last_reported: + return None + + # Close enough to correct for our purposes. + yesterday = (now - 24 * 60 * 60) + if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60: + return None + + txn.execute( + "SELECT COUNT(*) as messages" + " FROM events NATURAL JOIN event_json" + " WHERE json like '%m.room.message%'" + " AND stream_ordering > ?" + " AND stream_ordering <= ?", + ( + last_reported[0]["reported_stream_token"], + now_reporting, + ) + ) + rows = self.cursor_to_dict(txn) + if not rows: + return None + return rows[0]["messages"] + + ret = yield self.runInteraction("count_messages", _count_messages) + defer.returnValue(ret) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index c9ceb132ae..6d76237658 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -289,3 +289,15 @@ class RegistrationStore(SQLBaseStore): if ret: defer.returnValue(ret['user_id']) defer.returnValue(None) + + @defer.inlineCallbacks + def count_all_users(self): + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users") + rows = self.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/schema/delta/24/stats_reporting.sql new file mode 100644 index 0000000000..e9165d2917 --- /dev/null +++ b/synapse/storage/schema/delta/24/stats_reporting.sql @@ -0,0 +1,22 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Should only ever contain one row +CREATE TABLE IF NOT EXISTS stats_reporting( + -- The stream ordering token which was most recently reported as stats + reported_stream_token INTEGER, + -- The time (seconds since epoch) stats were most recently reported + reported_time BIGINT +); diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py new file mode 100644 index 0000000000..42bd8928bd --- /dev/null +++ b/tests/storage/event_injector.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.types import UserID, RoomID + +from tests.utils import setup_test_homeserver + +from mock import Mock + + +class EventInjector: + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.message_handler = hs.get_handlers().message_handler + self.event_builder_factory = hs.get_event_builder_factory() + + @defer.inlineCallbacks + def create_room(self, room): + builder = self.event_builder_factory.new({ + "type": EventTypes.Create, + "room_id": room.to_string(), + "content": {}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + @defer.inlineCallbacks + def inject_room_member(self, room, user, membership): + builder = self.event_builder_factory.new({ + "type": EventTypes.Member, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"membership": membership}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) + + defer.returnValue(event) + + @defer.inlineCallbacks + def inject_message(self, room, user, body): + builder = self.event_builder_factory.new({ + "type": EventTypes.Message, + "sender": user.to_string(), + "state_key": user.to_string(), + "room_id": room.to_string(), + "content": {"body": body, "msgtype": u"message"}, + }) + + event, context = yield self.message_handler._create_new_client_event( + builder + ) + + yield self.store.persist_event(event, context) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py new file mode 100644 index 0000000000..313013009e --- /dev/null +++ b/tests/storage/test_events.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import uuid +from mock.mock import Mock +from synapse.types import RoomID, UserID + +from tests import unittest +from twisted.internet import defer +from tests.storage.event_injector import EventInjector + +from tests.utils import setup_test_homeserver + + +class EventsStoreTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + resource_for_federation=Mock(), + http_client=None, + ) + self.store = self.hs.get_datastore() + self.db_pool = self.hs.get_db_pool() + self.message_handler = self.hs.get_handlers().message_handler + self.event_injector = EventInjector(self.hs) + + @defer.inlineCallbacks + def test_count_daily_messages(self): + self.db_pool.runQuery("DELETE FROM stats_reporting") + + self.hs.clock.now = 100 + + # Never reported before, and nothing which could be reported + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting") + self.assertEqual([(0,)], count) + + # Create something to report + room = RoomID.from_string("!abc123:test") + user = UserID.from_string("@raccoonlover:test") + yield self.event_injector.create_room(room) + + self.base_event = yield self._get_last_stream_token() + + yield self.event_injector.inject_message(room, user, "Raccoons are really cute") + + # Never reported before, something could be reported, but isn't because + # it isn't old enough. + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(1, self.hs.clock.now) + + # Already reported yesterday, two new events from today. + yield self.event_injector.inject_message(room, user, "Yeah they are!") + yield self.event_injector.inject_message(room, user, "Incredibly!") + self.hs.clock.now += 60 * 60 * 24 + count = yield self.store.count_daily_messages() + self.assertEqual(2, count) # 2 since yesterday + self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + + # Last reported too recently. + yield self.event_injector.inject_message(room, user, "Who could disagree?") + self.hs.clock.now += 60 * 60 * 22 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(4, self.hs.clock.now) + + # Last reported too long ago + yield self.event_injector.inject_message(room, user, "No one.") + self.hs.clock.now += 60 * 60 * 26 + count = yield self.store.count_daily_messages() + self.assertIsNone(count) + self._assert_stats_reporting(5, self.hs.clock.now) + + # And now let's actually report something + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + yield self.event_injector.inject_message(room, user, "Indeed.") + # A little over 24 hours is fine :) + self.hs.clock.now += (60 * 60 * 24) + 50 + count = yield self.store.count_daily_messages() + self.assertEqual(3, count) + self._assert_stats_reporting(8, self.hs.clock.now) + + @defer.inlineCallbacks + def _get_last_stream_token(self): + rows = yield self.db_pool.runQuery( + "SELECT stream_ordering" + " FROM events" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + if not rows: + defer.returnValue(0) + else: + defer.returnValue(rows[0][0]) + + @defer.inlineCallbacks + def _assert_stats_reporting(self, messages, time): + rows = yield self.db_pool.runQuery( + "SELECT reported_stream_token, reported_time FROM stats_reporting" + ) + self.assertEqual([(self.base_event + messages, time,)], rows) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ab7625a3ca..caffce64e3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() - self.event_factory = hs.get_event_factory(); + self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 0c9b89d765..a658a789aa 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID +from tests.storage.event_injector import EventInjector from tests.utils import setup_test_homeserver @@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase): self.store = hs.get_datastore() self.event_builder_factory = hs.get_event_builder_factory() + self.event_injector = EventInjector(hs) self.handlers = hs.get_handlers() self.message_handler = self.handlers.message_handler @@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") self.room2 = RoomID.from_string("!xyx987:test") - self.depth = 1 - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - self.depth += 1 - - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - @defer.inlineCallbacks def test_event_stream_get_other(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_get_own(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_join_leave(self): # Both bob and alice joins the room - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) # Then bob leaves again. - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.LEAVE ) # Initial stream key: start = yield self.store.get_room_events_max_id() - yield self.inject_message(self.room1, self.u_alice, u"test") + yield self.event_injector.inject_message(self.room1, self.u_alice, u"test") end = yield self.store.get_room_events_max_id() @@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_event_stream_prev_content(self): - yield self.inject_room_member( + yield self.event_injector.inject_room_member( self.room1, self.u_bob, Membership.JOIN ) - event1 = yield self.inject_room_member( + event1 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN ) start = yield self.store.get_room_events_max_id() - event2 = yield self.inject_room_member( + event2 = yield self.event_injector.inject_room_member( self.room1, self.u_alice, Membership.JOIN, ) -- cgit 1.5.1 From bb4dddd6c4f85bc5b07119d3f9dec31964b5b6f9 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 22 Sep 2015 18:33:34 +0100 Subject: Move NullSource out of synapse and into tests since it is only used by the tests --- synapse/streams/events.py | 16 ---------------- tests/rest/client/v1/test_presence.py | 18 +++++++++++++++++- 2 files changed, 17 insertions(+), 17 deletions(-) (limited to 'tests') diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 8671a8fa4e..699083ae12 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource -class NullSource(object): - """This event source never yields any events and its token remains at - zero. It may be useful for unit-testing.""" - def __init__(self, hs): - pass - - def get_new_events_for_user(self, user, from_key, limit): - return defer.succeed(([], from_key)) - - def get_current_key(self, direction='f'): - return defer.succeed(0) - - def get_pagination_rows(self, user, pagination_config, key): - return defer.succeed(([], pagination_config.from_key)) - - class EventSources(object): SOURCE_TYPES = { "room": RoomEventSource, diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 2ee3da0b34..29d9bbaad4 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -41,6 +41,22 @@ myid = "@apple:test" PATH_PREFIX = "/_matrix/client/api/v1" +class NullSource(object): + """This event source never yields any events and its token remains at + zero. It may be useful for unit-testing.""" + def __init__(self, hs): + pass + + def get_new_events_for_user(self, user, from_key, limit): + return defer.succeed(([], from_key)) + + def get_current_key(self, direction='f'): + return defer.succeed(0) + + def get_pagination_rows(self, user, pagination_config, key): + return defer.succeed(([], pagination_config.from_key)) + + class JustPresenceHandlers(object): def __init__(self, hs): self.presence_handler = PresenceHandler(hs) @@ -243,7 +259,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): # HIDEOUS HACKERY # TODO(paul): This should be injected in via the HomeServer DI system from synapse.streams.events import ( - PresenceEventSource, NullSource, EventSources + PresenceEventSource, EventSources ) old_SOURCE_TYPES = EventSources.SOURCE_TYPES -- cgit 1.5.1 From 1ee3d26432d87ff312350f21da982f646b5af49a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 23 Sep 2015 10:30:03 +0100 Subject: synapse/storage/_base.py:_simple_selectupdate_one was unused --- synapse/storage/_base.py | 31 ------------------------------- tests/storage/test_base.py | 20 -------------------- 2 files changed, 51 deletions(-) (limited to 'tests') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index cf4ec30f48..79021bde6b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -686,37 +686,6 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, - retcols=None, allow_none=False, - desc="_simple_selectupdate_one"): - """ Combined SELECT then UPDATE.""" - def func(txn): - ret = None - if retcols: - ret = self._simple_select_one_txn( - txn, - table=table, - keyvalues=keyvalues, - retcols=retcols, - allow_none=allow_none, - ) - - if updatevalues: - self._simple_update_one_txn( - txn, - table=table, - keyvalues=keyvalues, - updatevalues=updatevalues, - ) - - # if txn.rowcount == 0: - # raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched") - - return ret - return self.runInteraction(desc, func) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 8573f18b55..1ddca1da4c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -185,26 +185,6 @@ class SQLBaseStoreTestCase(unittest.TestCase): [3, 4, 1, 2] ) - @defer.inlineCallbacks - def test_update_one_with_return(self): - self.mock_txn.rowcount = 1 - self.mock_txn.fetchone.return_value = ("Old Value",) - - ret = yield self.datastore._simple_selectupdate_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columname": "New Value"}, - retcols=["columname"] - ) - - self.assertEquals({"columname": "Old Value"}, ret) - self.mock_txn.execute.assert_has_calls([ - call('SELECT columname FROM tablename WHERE keycol = ?', - ['TheKey']), - call("UPDATE tablename SET columname = ? WHERE keycol = ?", - ["New Value", "TheKey"]) - ]) - @defer.inlineCallbacks def test_delete_one(self): self.mock_txn.rowcount = 1 -- cgit 1.5.1 From ec398af41c4d276abb02279efbcbb0aa08a4cbc8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 11:41:04 +0100 Subject: Expose error more nicely --- synapse/app/homeserver.py | 5 +- synapse/storage/__init__.py | 3 - synapse/storage/_schema_prepare.py | 395 ------------------------------------ synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/schema_prepare.py | 395 ++++++++++++++++++++++++++++++++++++ tests/utils.py | 2 +- 7 files changed, 400 insertions(+), 404 deletions(-) delete mode 100644 synapse/storage/_schema_prepare.py create mode 100644 synapse/storage/schema_prepare.py (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 190b03e2f7..b284d07cf0 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -35,9 +35,8 @@ if __name__ == '__main__': from synapse.storage.engines import create_engine, IncorrectDatabaseSetup -from synapse.storage import ( - are_all_users_on_domain, UpgradeDatabaseException, -) +from synapse.storage import are_all_users_on_domain +from synapse.storage.schema_prepare import UpgradeDatabaseException from synapse.server import HomeServer diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4be629bff8..48a0633746 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -41,9 +41,6 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore -from ._schema_prepare import UpgradeDatabaseException - -__all__ = [UpgradeDatabaseException] import logging diff --git a/synapse/storage/_schema_prepare.py b/synapse/storage/_schema_prepare.py deleted file mode 100644 index 1ddf55be4d..0000000000 --- a/synapse/storage/_schema_prepare.py +++ /dev/null @@ -1,395 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fnmatch -import imp -import logging -import os -import re - - -logger = logging.getLogger(__name__) - - -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - - -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. - """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. - - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... */ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. - statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 949396044e..7e45dabf4c 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._schema_prepare import prepare_database +from synapse.storage.schema_prepare import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index a66815ef2d..0eeaa45d19 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage._schema_prepare import ( +from synapse.storage.schema_prepare import ( prepare_database, prepare_sqlite3_database ) diff --git a/synapse/storage/schema_prepare.py b/synapse/storage/schema_prepare.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/schema_prepare.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. + """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. + """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/tests/utils.py b/tests/utils.py index dd19a16fc7..6eb575bd09 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from synapse.http.server import HttpServer from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.constants import EventTypes -from synapse.storage import prepare_database +from synapse.storage.schema_prepare import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer -- cgit 1.5.1 From 17c80c8a3d92acca5bda9b0fc7d9898547476563 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2015 13:56:22 +0100 Subject: rename schema_prepare to prepare_database --- synapse/app/homeserver.py | 2 +- synapse/storage/engines/postgres.py | 2 +- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/prepare_database.py | 395 ++++++++++++++++++++++++++++++++++++ synapse/storage/schema_prepare.py | 395 ------------------------------------ tests/utils.py | 2 +- 6 files changed, 399 insertions(+), 399 deletions(-) create mode 100644 synapse/storage/prepare_database.py delete mode 100644 synapse/storage/schema_prepare.py (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b284d07cf0..af53acb369 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -36,7 +36,7 @@ if __name__ == '__main__': from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain -from synapse.storage.schema_prepare import UpgradeDatabaseException +from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.server import HomeServer diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 7e45dabf4c..98d66e0a86 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.schema_prepare import prepare_database +from synapse.storage.prepare_database import prepare_database from ._base import IncorrectDatabaseSetup diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 0eeaa45d19..bad3b5c5ac 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.schema_prepare import ( +from synapse.storage.prepare_database import ( prepare_database, prepare_sqlite3_database ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py new file mode 100644 index 0000000000..1ddf55be4d --- /dev/null +++ b/synapse/storage/prepare_database.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fnmatch +import imp +import logging +import os +import re + + +logger = logging.getLogger(__name__) + + +# Remember to update this number every time a change is made to database +# schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 24 + +dir_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_schema(path): + """ Read the named database schema. + + Args: + path: Path of the database schema. + Returns: + A string containing the database schema. + """ + with open(path) as schema_file: + return schema_file.read() + + +class PrepareDatabaseException(Exception): + pass + + +class UpgradeDatabaseException(PrepareDatabaseException): + pass + + +def prepare_database(db_conn, database_engine): + """Prepares a database for usage. Will either create all necessary tables + or upgrade from an older schema version. + """ + try: + cur = db_conn.cursor() + version_info = _get_or_create_schema_state(cur, database_engine) + + if version_info: + user_version, delta_files, upgraded = version_info + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) + else: + _setup_new_database(cur, database_engine) + + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + + cur.close() + db_conn.commit() + except: + db_conn.rollback() + raise + + +def _setup_new_database(cur, database_engine): + """Sets up the database by finding a base set of "full schemas" and then + applying any necessary deltas. + + The "full_schemas" directory has subdirectories named after versions. This + function searches for the highest version less than or equal to + `SCHEMA_VERSION` and executes all .sql files in that directory. + + The function will then apply all deltas for all versions after the base + version. + + Example directory structure: + + schema/ + delta/ + ... + full_schemas/ + 3/ + test.sql + ... + 11/ + foo.sql + bar.sql + ... + + In the example foo.sql and bar.sql would be run, and then any delta files + for versions strictly greater than 11. + """ + current_dir = os.path.join(dir_path, "schema", "full_schemas") + directory_entries = os.listdir(current_dir) + + valid_dirs = [] + pattern = re.compile(r"^\d+(\.sql)?$") + for filename in directory_entries: + match = pattern.match(filename) + abs_path = os.path.join(current_dir, filename) + if match and os.path.isdir(abs_path): + ver = int(match.group(0)) + if ver <= SCHEMA_VERSION: + valid_dirs.append((ver, abs_path)) + else: + logger.warn("Unexpected entry in 'full_schemas': %s", filename) + + if not valid_dirs: + raise PrepareDatabaseException( + "Could not find a suitable base set of full schemas" + ) + + max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) + + logger.debug("Initialising schema v%d", max_current_ver) + + directory_entries = os.listdir(sql_dir) + + for filename in fnmatch.filter(directory_entries, "*.sql"): + sql_loc = os.path.join(sql_dir, filename) + logger.debug("Applying schema %s", sql_loc) + executescript(cur, sql_loc) + + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) + ) + + _upgrade_existing_database( + cur, + current_version=max_current_ver, + applied_delta_files=[], + upgraded=False, + database_engine=database_engine, + ) + + +def _upgrade_existing_database(cur, current_version, applied_delta_files, + upgraded, database_engine): + """Upgrades an existing database. + + Delta files can either be SQL stored in *.sql files, or python modules + in *.py. + + There can be multiple delta files per version. Synapse will keep track of + which delta files have been applied, and will apply any that haven't been + even if there has been no version bump. This is useful for development + where orthogonal schema changes may happen on separate branches. + + Different delta files for the same version *must* be orthogonal and give + the same result when applied in any order. No guarantees are made on the + order of execution of these scripts. + + This is a no-op of current_version == SCHEMA_VERSION. + + Example directory structure: + + schema/ + delta/ + 11/ + foo.sql + ... + 12/ + foo.sql + bar.py + ... + full_schemas/ + ... + + In the example, if current_version is 11, then foo.sql will be run if and + only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in + some arbitrary order. + + Args: + cur (Cursor) + current_version (int): The current version of the schema. + applied_delta_files (list): A list of deltas that have already been + applied. + upgraded (bool): Whether the current version was generated by having + applied deltas or from full schema file. If `True` the function + will never apply delta files for the given `current_version`, since + the current_version wasn't generated by applying those delta files. + """ + + if current_version > SCHEMA_VERSION: + raise ValueError( + "Cannot use this database as it is too " + + "new for the server to understand" + ) + + start_ver = current_version + if not upgraded: + start_ver += 1 + + logger.debug("applied_delta_files: %s", applied_delta_files) + + for v in range(start_ver, SCHEMA_VERSION + 1): + logger.debug("Upgrading schema to v%d", v) + + delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) + + try: + directory_entries = os.listdir(delta_dir) + except OSError: + logger.exception("Could not open delta dir for version %d", v) + raise UpgradeDatabaseException( + "Could not open delta dir for version %d" % (v,) + ) + + directory_entries.sort() + for file_name in directory_entries: + relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) + if relative_path in applied_delta_files: + continue + + absolute_path = os.path.join( + dir_path, "schema", "delta", relative_path, + ) + root_name, ext = os.path.splitext(file_name) + if ext == ".py": + # This is a python upgrade module. We need to import into some + # package and then execute its `run_upgrade` function. + module_name = "synapse.storage.v%d_%s" % ( + v, root_name + ) + with open(absolute_path) as python_file: + module = imp.load_source( + module_name, absolute_path, python_file + ) + logger.debug("Running script %s", relative_path) + module.run_upgrade(cur, database_engine) + elif ext == ".pyc": + # Sometimes .pyc files turn up anyway even though we've + # disabled their generation; e.g. from distribution package + # installers. Silently skip it + pass + elif ext == ".sql": + # A plain old .sql file, just read and execute it + logger.debug("Applying schema %s", relative_path) + executescript(cur, absolute_path) + else: + # Not a valid delta file. + logger.warn( + "Found directory entry that did not end in .py or" + " .sql: %s", + relative_path, + ) + continue + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), + (v, relative_path) + ) + + cur.execute("DELETE FROM schema_version") + cur.execute( + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), + (v, True) + ) + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + executescript(txn, schema_path) + + txn.execute("SELECT version, upgraded FROM schema_version") + row = txn.fetchone() + current_version = int(row[0]) if row else None + upgraded = bool(row[1]) if row else None + + if current_version: + txn.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), + (current_version,) + ) + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded + + return None + + +def prepare_sqlite3_database(db_conn): + """This function should be called before `prepare_database` on sqlite3 + databases. + + Since we changed the way we store the current schema version and handle + updates to schemas, we need a way to upgrade from the old method to the + new. This only affects sqlite databases since they were the only ones + supported at the time. + """ + with db_conn: + schema_path = os.path.join( + dir_path, "schema", "schema_version.sql", + ) + create_schema = read_schema(schema_path) + db_conn.executescript(create_schema) + + c = db_conn.execute("SELECT * FROM schema_version") + rows = c.fetchall() + c.close() + + if not rows: + c = db_conn.execute("PRAGMA user_version") + row = c.fetchone() + c.close() + + if row and row[0]: + db_conn.execute( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)", + (row[0], False) + ) diff --git a/synapse/storage/schema_prepare.py b/synapse/storage/schema_prepare.py deleted file mode 100644 index 1ddf55be4d..0000000000 --- a/synapse/storage/schema_prepare.py +++ /dev/null @@ -1,395 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fnmatch -import imp -import logging -import os -import re - - -logger = logging.getLogger(__name__) - - -# Remember to update this number every time a change is made to database -# schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 24 - -dir_path = os.path.abspath(os.path.dirname(__file__)) - - -def read_schema(path): - """ Read the named database schema. - - Args: - path: Path of the database schema. - Returns: - A string containing the database schema. - """ - with open(path) as schema_file: - return schema_file.read() - - -class PrepareDatabaseException(Exception): - pass - - -class UpgradeDatabaseException(PrepareDatabaseException): - pass - - -def prepare_database(db_conn, database_engine): - """Prepares a database for usage. Will either create all necessary tables - or upgrade from an older schema version. - """ - try: - cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur, database_engine) - - if version_info: - user_version, delta_files, upgraded = version_info - _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine - ) - else: - _setup_new_database(cur, database_engine) - - # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) - - cur.close() - db_conn.commit() - except: - db_conn.rollback() - raise - - -def _setup_new_database(cur, database_engine): - """Sets up the database by finding a base set of "full schemas" and then - applying any necessary deltas. - - The "full_schemas" directory has subdirectories named after versions. This - function searches for the highest version less than or equal to - `SCHEMA_VERSION` and executes all .sql files in that directory. - - The function will then apply all deltas for all versions after the base - version. - - Example directory structure: - - schema/ - delta/ - ... - full_schemas/ - 3/ - test.sql - ... - 11/ - foo.sql - bar.sql - ... - - In the example foo.sql and bar.sql would be run, and then any delta files - for versions strictly greater than 11. - """ - current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) - - valid_dirs = [] - pattern = re.compile(r"^\d+(\.sql)?$") - for filename in directory_entries: - match = pattern.match(filename) - abs_path = os.path.join(current_dir, filename) - if match and os.path.isdir(abs_path): - ver = int(match.group(0)) - if ver <= SCHEMA_VERSION: - valid_dirs.append((ver, abs_path)) - else: - logger.warn("Unexpected entry in 'full_schemas': %s", filename) - - if not valid_dirs: - raise PrepareDatabaseException( - "Could not find a suitable base set of full schemas" - ) - - max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0]) - - logger.debug("Initialising schema v%d", max_current_ver) - - directory_entries = os.listdir(sql_dir) - - for filename in fnmatch.filter(directory_entries, "*.sql"): - sql_loc = os.path.join(sql_dir, filename) - logger.debug("Applying schema %s", sql_loc) - executescript(cur, sql_loc) - - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)" - ), - (max_current_ver, False,) - ) - - _upgrade_existing_database( - cur, - current_version=max_current_ver, - applied_delta_files=[], - upgraded=False, - database_engine=database_engine, - ) - - -def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): - """Upgrades an existing database. - - Delta files can either be SQL stored in *.sql files, or python modules - in *.py. - - There can be multiple delta files per version. Synapse will keep track of - which delta files have been applied, and will apply any that haven't been - even if there has been no version bump. This is useful for development - where orthogonal schema changes may happen on separate branches. - - Different delta files for the same version *must* be orthogonal and give - the same result when applied in any order. No guarantees are made on the - order of execution of these scripts. - - This is a no-op of current_version == SCHEMA_VERSION. - - Example directory structure: - - schema/ - delta/ - 11/ - foo.sql - ... - 12/ - foo.sql - bar.py - ... - full_schemas/ - ... - - In the example, if current_version is 11, then foo.sql will be run if and - only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in - some arbitrary order. - - Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having - applied deltas or from full schema file. If `True` the function - will never apply delta files for the given `current_version`, since - the current_version wasn't generated by applying those delta files. - """ - - if current_version > SCHEMA_VERSION: - raise ValueError( - "Cannot use this database as it is too " + - "new for the server to understand" - ) - - start_ver = current_version - if not upgraded: - start_ver += 1 - - logger.debug("applied_delta_files: %s", applied_delta_files) - - for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) - - delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) - - try: - directory_entries = os.listdir(delta_dir) - except OSError: - logger.exception("Could not open delta dir for version %d", v) - raise UpgradeDatabaseException( - "Could not open delta dir for version %d" % (v,) - ) - - directory_entries.sort() - for file_name in directory_entries: - relative_path = os.path.join(str(v), file_name) - logger.debug("Found file: %s", relative_path) - if relative_path in applied_delta_files: - continue - - absolute_path = os.path.join( - dir_path, "schema", "delta", relative_path, - ) - root_name, ext = os.path.splitext(file_name) - if ext == ".py": - # This is a python upgrade module. We need to import into some - # package and then execute its `run_upgrade` function. - module_name = "synapse.storage.v%d_%s" % ( - v, root_name - ) - with open(absolute_path) as python_file: - module = imp.load_source( - module_name, absolute_path, python_file - ) - logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) - elif ext == ".pyc": - # Sometimes .pyc files turn up anyway even though we've - # disabled their generation; e.g. from distribution package - # installers. Silently skip it - pass - elif ext == ".sql": - # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) - executescript(cur, absolute_path) - else: - # Not a valid delta file. - logger.warn( - "Found directory entry that did not end in .py or" - " .sql: %s", - relative_path, - ) - continue - - # Mark as done. - cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", - ), - (v, relative_path) - ) - - cur.execute("DELETE FROM schema_version") - cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded)" - " VALUES (?,?)", - ), - (v, True) - ) - - -def get_statements(f): - statement_buffer = "" - in_comment = False # If we're in a /* ... */ style comment - - for line in f: - line = line.strip() - - if in_comment: - # Check if this line contains an end to the comment - comments = line.split("*/", 1) - if len(comments) == 1: - continue - line = comments[1] - in_comment = False - - # Remove inline block comments - line = re.sub(r"/\*.*\*/", " ", line) - - # Does this line start a comment? - comments = line.split("/*", 1) - if len(comments) > 1: - line = comments[0] - in_comment = True - - # Deal with line comments - line = line.split("--", 1)[0] - line = line.split("//", 1)[0] - - # Find *all* semicolons. We need to treat first and last entry - # specially. - statements = line.split(";") - - # We must prepend statement_buffer to the first statement - first_statement = "%s %s" % ( - statement_buffer.strip(), - statements[0].strip() - ) - statements[0] = first_statement - - # Every entry, except the last, is a full statement - for statement in statements[:-1]: - yield statement.strip() - - # The last entry did *not* end in a semicolon, so we store it for the - # next semicolon we find - statement_buffer = statements[-1].strip() - - -def executescript(txn, schema_path): - with open(schema_path, 'r') as f: - for statement in get_statements(f): - txn.execute(statement) - - -def _get_or_create_schema_state(txn, database_engine): - # Bluntly try creating the schema_version tables. - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - executescript(txn, schema_path) - - txn.execute("SELECT version, upgraded FROM schema_version") - row = txn.fetchone() - current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None - - if current_version: - txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), - (current_version,) - ) - applied_deltas = [d for d, in txn.fetchall()] - return current_version, applied_deltas, upgraded - - return None - - -def prepare_sqlite3_database(db_conn): - """This function should be called before `prepare_database` on sqlite3 - databases. - - Since we changed the way we store the current schema version and handle - updates to schemas, we need a way to upgrade from the old method to the - new. This only affects sqlite databases since they were the only ones - supported at the time. - """ - with db_conn: - schema_path = os.path.join( - dir_path, "schema", "schema_version.sql", - ) - create_schema = read_schema(schema_path) - db_conn.executescript(create_schema) - - c = db_conn.execute("SELECT * FROM schema_version") - rows = c.fetchall() - c.close() - - if not rows: - c = db_conn.execute("PRAGMA user_version") - row = c.fetchone() - c.close() - - if row and row[0]: - db_conn.execute( - "REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (row[0], False) - ) diff --git a/tests/utils.py b/tests/utils.py index 6eb575bd09..4da51291a4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,7 @@ from synapse.http.server import HttpServer from synapse.api.errors import cs_error, CodeMessageException, StoreError from synapse.api.constants import EventTypes -from synapse.storage.schema_prepare import prepare_database +from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer -- cgit 1.5.1 From 889778155811277585debda837c359a4ae471706 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 13 Oct 2015 14:13:51 +0100 Subject: update filtering tests --- tests/api/test_filtering.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 65b2f590c8..6942cdac51 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -345,9 +345,9 @@ class FilteringTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_filter_public_user_data_match(self): + def test_filter_presence_match(self): user_filter_json = { - "public_user_data": { + "presence": { "types": ["m.*"] } } @@ -368,13 +368,13 @@ class FilteringTestCase(unittest.TestCase): filter_id=filter_id, ) - results = user_filter.filter_public_user_data(events=events) + results = user_filter.filter_presence(events=events) self.assertEquals(events, results) @defer.inlineCallbacks - def test_filter_public_user_data_no_match(self): + def test_filter_presence_no_match(self): user_filter_json = { - "public_user_data": { + "presence": { "types": ["m.*"] } } @@ -395,7 +395,7 @@ class FilteringTestCase(unittest.TestCase): filter_id=filter_id, ) - results = user_filter.filter_public_user_data(events=events) + results = user_filter.filter_presence(events=events) self.assertEquals([], results) @defer.inlineCallbacks -- cgit 1.5.1 From aff4d850bdc5d6108b1f6f84591b44db6e496d75 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Fri, 16 Oct 2015 19:56:46 +0100 Subject: Add some unit tests of prune_events() --- tests/events/__init__.py | 0 tests/events/test_utils.py | 115 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 tests/events/__init__.py create mode 100644 tests/events/test_utils.py (limited to 'tests') diff --git a/tests/events/__init__.py b/tests/events/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py new file mode 100644 index 0000000000..16179921f0 --- /dev/null +++ b/tests/events/test_utils.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .. import unittest + +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +class PruneEventTestCase(unittest.TestCase): + """ Asserts that a new event constructed with `evdict` will look like + `matchdict` when it is redacted. """ + def run_test(self, evdict, matchdict): + self.assertEquals( + prune_event(FrozenEvent(evdict)).get_dict(), + matchdict + ) + + def test_minimal(self): + self.run_test( + {'type': 'A'}, + { + 'type': 'A', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + def test_basic_keys(self): + self.run_test( + { + 'type': 'A', + 'room_id': '!1:domain', + 'sender': '@2:domain', + 'event_id': '$3:domain', + 'origin': 'domain', + }, + { + 'type': 'A', + 'room_id': '!1:domain', + 'sender': '@2:domain', + 'event_id': '$3:domain', + 'origin': 'domain', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + def test_unsigned_age_ts(self): + self.run_test( + { + 'type': 'B', + 'unsigned': {'age_ts': 20}, + }, + { + 'type': 'B', + 'content': {}, + 'signatures': {}, + 'unsigned': {'age_ts': 20}, + } + ) + + self.run_test( + { + 'type': 'B', + 'unsigned': {'other_key': 'here'}, + }, + { + 'type': 'B', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + def test_content(self): + self.run_test( + { + 'type': 'C', + 'content': {'things': 'here'}, + }, + { + 'type': 'C', + 'content': {}, + 'signatures': {}, + 'unsigned': {}, + } + ) + + self.run_test( + { + 'type': 'm.room.create', + 'content': {'creator': '@2:domain', 'other_field': 'here'}, + }, + { + 'type': 'm.room.create', + 'content': {'creator': '@2:domain'}, + 'signatures': {}, + 'unsigned': {}, + } + ) -- cgit 1.5.1 From 0aab34004b2e56c3ab79f514be264c568ad71fd3 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Mon, 19 Oct 2015 14:40:15 +0100 Subject: Initial minimial hack at a test of event hashing and signing --- tests/crypto/test_event_signing.py | 98 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 tests/crypto/test_event_signing.py (limited to 'tests') diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py new file mode 100644 index 0000000000..0b560e9317 --- /dev/null +++ b/tests/crypto/test_event_signing.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest +from tests.utils import MockClock + +from synapse.events.builder import EventBuilderFactory +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.types import EventID + +from unpaddedbase64 import decode_base64 + +import nacl.signing + + +# Perform these tests using given secret key so we get entirely deterministic +# signatures output that we can test against. +SIGNING_KEY_SEED = decode_base64( + "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1" +) + +KEY_ALG = "ed25519" +KEY_VER = 1 +KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER) + +HOSTNAME = "domain" + + +class EventBuilderFactoryWithPredicableIDs(EventBuilderFactory): + """ A subclass of EventBuilderFactory that generates entirely predicatable + event IDs, so we can assert on them. """ + def create_event_id(self): + i = str(self.event_id_count) + self.event_id_count += 1 + + return EventID.create(i, self.hostname).to_string() + + +class EventSigningTestCase(unittest.TestCase): + + def setUp(self): + self.event_builder_factory = EventBuilderFactoryWithPredicableIDs( + clock=MockClock(), + hostname=HOSTNAME, + ) + + self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED) + self.signing_key.alg = KEY_ALG + self.signing_key.version = KEY_VER + + def test_sign(self): + builder = self.event_builder_factory.new( + {'type': "X"} + ) + self.assertEquals( + builder.build().get_dict(), + { + 'event_id': "$0:domain", + 'origin': "domain", + 'origin_server_ts': 1000000, + 'signatures': {}, + 'type': "X", + 'unsigned': {'age_ts': 1000000}, + }, + ) + + add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) + + event = builder.build() + + self.assertTrue(hasattr(event, 'hashes')) + self.assertTrue('sha256' in event.hashes) + self.assertEquals( + event.hashes['sha256'], + "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI", + ) + + self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(HOSTNAME in event.signatures) + self.assertTrue(KEY_NAME in event.signatures["domain"]) + self.assertEquals( + event.signatures[HOSTNAME][KEY_NAME], + "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" + "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA", + ) -- cgit 1.5.1 From 07b58a431f9e0367f8c08d2bc8983473c8a0c379 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Mon, 19 Oct 2015 15:00:52 +0100 Subject: Another signing test vector using an 'm.room.message' with content, so that the implementation will have to redact it --- tests/crypto/test_event_signing.py | 50 +++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 0b560e9317..0f487d9c7b 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -61,7 +61,7 @@ class EventSigningTestCase(unittest.TestCase): self.signing_key.alg = KEY_ALG self.signing_key.version = KEY_VER - def test_sign(self): + def test_sign_minimal(self): builder = self.event_builder_factory.new( {'type': "X"} ) @@ -96,3 +96,51 @@ class EventSigningTestCase(unittest.TestCase): "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA", ) + + def test_sign_message(self): + builder = self.event_builder_factory.new( + { + 'type': "m.room.message", + 'sender': "@u:domain", + 'room_id': "!r:domain", + 'content': { + 'body': "Here is the message content", + }, + } + ) + self.assertEquals( + builder.build().get_dict(), + { + 'content': { + 'body': "Here is the message content", + }, + 'event_id': "$0:domain", + 'origin': "domain", + 'origin_server_ts': 1000000, + 'type': "m.room.message", + 'room_id': "!r:domain", + 'sender': "@u:domain", + 'signatures': {}, + 'unsigned': {'age_ts': 1000000}, + } + ) + + add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) + + event = builder.build() + + self.assertTrue(hasattr(event, 'hashes')) + self.assertTrue('sha256' in event.hashes) + self.assertEquals( + event.hashes['sha256'], + "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g", + ) + + self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(HOSTNAME in event.signatures) + self.assertTrue(KEY_NAME in event.signatures["domain"]) + self.assertEquals( + event.signatures[HOSTNAME][KEY_NAME], + "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" + "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA" + ) -- cgit 1.5.1 From a8795c9644d555e95a6be3211b4e79e447087697 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Mon, 19 Oct 2015 15:24:49 +0100 Subject: Use assertIn() instead of assertTrue on the 'in' operator --- tests/crypto/test_event_signing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 0f487d9c7b..010fe4ed33 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -82,15 +82,15 @@ class EventSigningTestCase(unittest.TestCase): event = builder.build() self.assertTrue(hasattr(event, 'hashes')) - self.assertTrue('sha256' in event.hashes) + self.assertIn('sha256', event.hashes) self.assertEquals( event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI", ) self.assertTrue(hasattr(event, 'signatures')) - self.assertTrue(HOSTNAME in event.signatures) - self.assertTrue(KEY_NAME in event.signatures["domain"]) + self.assertIn(HOSTNAME, event.signatures) + self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( event.signatures[HOSTNAME][KEY_NAME], "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+" @@ -130,15 +130,15 @@ class EventSigningTestCase(unittest.TestCase): event = builder.build() self.assertTrue(hasattr(event, 'hashes')) - self.assertTrue('sha256' in event.hashes) + self.assertIn('sha256', event.hashes) self.assertEquals( event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g", ) self.assertTrue(hasattr(event, 'signatures')) - self.assertTrue(HOSTNAME in event.signatures) - self.assertTrue(KEY_NAME in event.signatures["domain"]) + self.assertIn(HOSTNAME, event.signatures) + self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( event.signatures[HOSTNAME][KEY_NAME], "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" -- cgit 1.5.1 From 531e3aa75effdec137c1ffbdb1fb0e8cb0cbe40e Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Mon, 19 Oct 2015 17:37:35 +0100 Subject: Capture __init__.py --- tests/crypto/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 tests/crypto/__init__.py (limited to 'tests') diff --git a/tests/crypto/__init__.py b/tests/crypto/__init__.py new file mode 100644 index 0000000000..9bff9ec169 --- /dev/null +++ b/tests/crypto/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + -- cgit 1.5.1 From 9ed784098a94cf80d2582cc1d98484ac9d748eee Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Mon, 19 Oct 2015 17:42:34 +0100 Subject: Invoke EventBuilder directly instead of going via the EventBuilderFactory --- tests/crypto/test_event_signing.py | 38 +++----------------------------------- 1 file changed, 3 insertions(+), 35 deletions(-) (limited to 'tests') diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 010fe4ed33..7913472941 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -15,11 +15,9 @@ from tests import unittest -from tests.utils import MockClock -from synapse.events.builder import EventBuilderFactory +from synapse.events.builder import EventBuilder from synapse.crypto.event_signing import add_hashes_and_signatures -from synapse.types import EventID from unpaddedbase64 import decode_base64 @@ -39,34 +37,15 @@ KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER) HOSTNAME = "domain" -class EventBuilderFactoryWithPredicableIDs(EventBuilderFactory): - """ A subclass of EventBuilderFactory that generates entirely predicatable - event IDs, so we can assert on them. """ - def create_event_id(self): - i = str(self.event_id_count) - self.event_id_count += 1 - - return EventID.create(i, self.hostname).to_string() - - class EventSigningTestCase(unittest.TestCase): def setUp(self): - self.event_builder_factory = EventBuilderFactoryWithPredicableIDs( - clock=MockClock(), - hostname=HOSTNAME, - ) - self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED) self.signing_key.alg = KEY_ALG self.signing_key.version = KEY_VER def test_sign_minimal(self): - builder = self.event_builder_factory.new( - {'type': "X"} - ) - self.assertEquals( - builder.build().get_dict(), + builder = EventBuilder( { 'event_id': "$0:domain", 'origin': "domain", @@ -98,18 +77,7 @@ class EventSigningTestCase(unittest.TestCase): ) def test_sign_message(self): - builder = self.event_builder_factory.new( - { - 'type': "m.room.message", - 'sender': "@u:domain", - 'room_id': "!r:domain", - 'content': { - 'body': "Here is the message content", - }, - } - ) - self.assertEquals( - builder.build().get_dict(), + builder = EventBuilder( { 'content': { 'body': "Here is the message content", -- cgit 1.5.1 From 137fafce4ee06e76b05d37807611e10055059f62 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 20 Oct 2015 11:58:58 +0100 Subject: Allow rejecting invites This is done by using the same /leave flow as you would use if you had already accepted the invite and wanted to leave. --- synapse/api/auth.py | 6 +- synapse/federation/federation_client.py | 67 +++++++++- synapse/federation/federation_server.py | 14 +++ synapse/federation/transport/client.py | 24 +++- synapse/federation/transport/server.py | 20 +++ synapse/handlers/federation.py | 209 +++++++++++++++++++++++++------- synapse/handlers/room.py | 102 +++++++++------- tests/rest/client/v1/test_rooms.py | 4 +- 8 files changed, 353 insertions(+), 93 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index cf19eda4e9..494c8ac3d4 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -308,7 +308,11 @@ class Auth(object): ) if Membership.JOIN != membership: - # JOIN is the only action you can perform if you're not in the room + if (caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id): + return True + if not caller_in_room: # caller isn't joined raise AuthError( 403, diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f5b430e046..723f571284 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -17,6 +17,7 @@ from twisted.internet import defer from .federation_base import FederationBase +from synapse.api.constants import Membership from .units import Edu from synapse.api.errors import ( @@ -357,7 +358,34 @@ class FederationClient(FederationBase): defer.returnValue(signed_auth) @defer.inlineCallbacks - def make_join(self, destinations, room_id, user_id, content): + def make_membership_event(self, destinations, room_id, user_id, membership, content): + """ + Creates an m.room.member event, with context, without participating in the room. + + Does so by asking one of the already participating servers to create an + event with proper context. + + Note that this does not append any events to any graphs. + + Args: + destinations (str): Candidate homeservers which are probably + participating in the room. + room_id (str): The room in which the event will happen. + user_id (str): The user whose membership is being evented. + membership (str): The "membership" property of the event. Must be + one of "join" or "leave". + content (object): Any additional data to put into the content field + of the event. + Return: + A tuple of (origin (str), event (object)) where origin is the remote + homeserver which generated the event. + """ + valid_memberships = {Membership.JOIN, Membership.LEAVE} + if membership not in valid_memberships: + raise RuntimeError( + "make_membership_event called with membership='%s', must be one of %s" % + (membership, ",".join(valid_memberships)) + ) for destination in destinations: if destination == self.server_name: continue @@ -368,13 +396,13 @@ class FederationClient(FederationBase): content["third_party_invite"] ) try: - ret = yield self.transport_layer.make_join( - destination, room_id, user_id, args + ret = yield self.transport_layer.make_membership_event( + destination, room_id, user_id, membership, args ) pdu_dict = ret["event"] - logger.debug("Got response to make_join: %s", pdu_dict) + logger.debug("Got response to make_%s: %s", membership, pdu_dict) defer.returnValue( (destination, self.event_from_pdu_json(pdu_dict)) @@ -384,8 +412,8 @@ class FederationClient(FederationBase): raise except Exception as e: logger.warn( - "Failed to make_join via %s: %s", - destination, e.message + "Failed to make_%s via %s: %s", + membership, destination, e.message ) raise RuntimeError("Failed to send to any server.") @@ -491,6 +519,33 @@ class FederationClient(FederationBase): defer.returnValue(pdu) + @defer.inlineCallbacks + def send_leave(self, destinations, pdu): + for destination in destinations: + if destination == self.server_name: + continue + + try: + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_leave( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + logger.debug("Got content: %s", content) + defer.returnValue(None) + except CodeMessageException: + raise + except Exception as e: + logger.exception( + "Failed to send_leave via %s: %s", + destination, e.message + ) + + raise RuntimeError("Failed to send to any server.") + @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): """ diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 7934f740e0..9e2d9ee74c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -267,6 +267,20 @@ class FederationServer(FederationBase): ], })) + @defer.inlineCallbacks + def on_make_leave_request(self, room_id, user_id): + pdu = yield self.handler.on_make_leave_request(room_id, user_id) + time_now = self._clock.time_msec() + defer.returnValue({"event": pdu.get_pdu_json(time_now)}) + + @defer.inlineCallbacks + def on_send_leave_request(self, origin, content): + logger.debug("on_send_leave_request: content: %s", content) + pdu = self.event_from_pdu_json(content) + logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) + yield self.handler.on_send_leave_request(origin, pdu) + defer.returnValue((200, {})) + @defer.inlineCallbacks def on_event_auth(self, origin, room_id, event_id): time_now = self._clock.time_msec() diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index ae4195e83a..a81b3c4345 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -14,6 +14,7 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.constants import Membership from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.util.logutils import log_function @@ -160,8 +161,14 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def make_join(self, destination, room_id, user_id, args={}): - path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) + def make_membership_event(self, destination, room_id, user_id, membership, args={}): + valid_memberships = {Membership.JOIN, Membership.LEAVE} + if membership not in valid_memberships: + raise RuntimeError( + "make_membership_event called with membership='%s', must be one of %s" % + (membership, ",".join(valid_memberships)) + ) + path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id) content = yield self.client.get_json( destination=destination, @@ -185,6 +192,19 @@ class TransportLayerClient(object): defer.returnValue(response) + @defer.inlineCallbacks + @log_function + def send_leave(self, destination, room_id, event_id, content): + path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id) + + response = yield self.client.put_json( + destination=destination, + path=path, + data=content, + ) + + defer.returnValue(response) + @defer.inlineCallbacks @log_function def send_invite(self, destination, room_id, event_id, content): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 6e394f039e..8184159210 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -296,6 +296,24 @@ class FederationMakeJoinServlet(BaseFederationServlet): defer.returnValue((200, content)) +class FederationMakeLeaveServlet(BaseFederationServlet): + PATH = "/make_leave/([^/]*)/([^/]*)" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, context, user_id): + content = yield self.handler.on_make_leave_request(context, user_id) + defer.returnValue((200, content)) + + +class FederationSendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/([^/]*)/([^/]*)" + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, room_id, txid): + content = yield self.handler.on_send_leave_request(origin, content) + defer.returnValue((200, content)) + + class FederationEventAuthServlet(BaseFederationServlet): PATH = "/event_auth/([^/]*)/([^/]*)" @@ -385,8 +403,10 @@ SERVLET_CLASSES = ( FederationBackfillServlet, FederationQueryServlet, FederationMakeJoinServlet, + FederationMakeLeaveServlet, FederationEventServlet, FederationSendJoinServlet, + FederationSendLeaveServlet, FederationInviteServlet, FederationQueryAuthServlet, FederationGetMissingEventsServlet, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 946ff97c7d..ae9d227586 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -565,7 +565,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot): + def do_invite_join(self, target_hosts, room_id, joinee, content): """ Attempts to join the `joinee` to the room `room_id` via the server `target_host`. @@ -581,50 +581,19 @@ class FederationHandler(BaseHandler): yield self.store.clean_room_for_join(room_id) - origin, pdu = yield self.replication_layer.make_join( + origin, event = yield self._make_and_verify_event( target_hosts, room_id, joinee, + "join", content ) - logger.debug("Got response to make_join: %s", pdu) - - event = pdu - - # We should assert some things. - # FIXME: Do this in a nicer way - assert(event.type == EventTypes.Member) - assert(event.user_id == joinee) - assert(event.state_key == joinee) - assert(event.room_id == room_id) - - event.internal_metadata.outlier = False - self.room_queues[room_id] = [] - - builder = self.event_builder_factory.new( - unfreeze(event.get_pdu_json()) - ) - handled_events = set() try: - builder.event_id = self.event_builder_factory.create_event_id() - builder.origin = self.hs.hostname - builder.content = content - - if not hasattr(event, "signatures"): - builder.signatures = {} - - add_hashes_and_signatures( - builder, - self.hs.hostname, - self.hs.config.signing_key[0], - ) - - new_event = builder.build() - + new_event = self._sign_event(event) # Try the host we successfully got a response to /make_join/ # request first. try: @@ -632,11 +601,7 @@ class FederationHandler(BaseHandler): target_hosts.insert(0, origin) except ValueError: pass - - ret = yield self.replication_layer.send_join( - target_hosts, - new_event - ) + ret = yield self.replication_layer.send_join(target_hosts, new_event) origin = ret["origin"] state = ret["state"] @@ -700,7 +665,7 @@ class FederationHandler(BaseHandler): @log_function def on_make_join_request(self, room_id, user_id, query): """ We've received a /make_join/ request, so we create a partial - join event for the room and return that. We don *not* persist or + join event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. """ event_content = {"membership": Membership.JOIN} @@ -859,6 +824,168 @@ class FederationHandler(BaseHandler): defer.returnValue(event) + @defer.inlineCallbacks + def do_remotely_reject_invite(self, target_hosts, room_id, user_id): + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave", + {} + ) + signed_event = self._sign_event(event) + + # Try the host we successfully got a response to /make_join/ + # request first. + try: + target_hosts.remove(origin) + target_hosts.insert(0, origin) + except ValueError: + pass + + yield self.replication_layer.send_leave( + target_hosts, + signed_event + ) + defer.returnValue(None) + + @defer.inlineCallbacks + def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content): + origin, pdu = yield self.replication_layer.make_membership_event( + target_hosts, + room_id, + user_id, + membership, + content + ) + + logger.debug("Got response to make_%s: %s", membership, pdu) + + event = pdu + + # We should assert some things. + # FIXME: Do this in a nicer way + assert(event.type == EventTypes.Member) + assert(event.user_id == user_id) + assert(event.state_key == user_id) + assert(event.room_id == room_id) + defer.returnValue((origin, event)) + + def _sign_event(self, event): + event.internal_metadata.outlier = False + + builder = self.event_builder_factory.new( + unfreeze(event.get_pdu_json()) + ) + + builder.event_id = self.event_builder_factory.create_event_id() + builder.origin = self.hs.hostname + + if not hasattr(event, "signatures"): + builder.signatures = {} + + add_hashes_and_signatures( + builder, + self.hs.hostname, + self.hs.config.signing_key[0], + ) + + return builder.build() + + @defer.inlineCallbacks + @log_function + def on_make_leave_request(self, room_id, user_id): + """ We've received a /make_leave/ request, so we create a partial + join event for the room and return that. We do *not* persist or + process it until the other server has signed it and sent it back. + """ + builder = self.event_builder_factory.new({ + "type": EventTypes.Member, + "content": {"membership": Membership.LEAVE}, + "room_id": room_id, + "sender": user_id, + "state_key": user_id, + }) + + event, context = yield self._create_new_client_event( + builder=builder, + ) + + self.auth.check(event, auth_events=context.current_state) + + defer.returnValue(event) + + @defer.inlineCallbacks + @log_function + def on_send_leave_request(self, origin, pdu): + """ We have received a leave event for a room. Fully process it.""" + event = pdu + + logger.debug( + "on_send_leave_request: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + event.internal_metadata.outlier = False + + context, event_stream_id, max_stream_id = yield self._handle_new_event( + origin, event + ) + + logger.debug( + "on_send_leave_request: After _handle_new_event: %s, sigs: %s", + event.event_id, + event.signatures, + ) + + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) + + with PreserveLoggingContext(): + d = self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) + + def log_failure(f): + logger.warn( + "Failed to notify about %s: %s", + event.event_id, f.value + ) + + d.addErrback(log_failure) + + new_pdu = event + + destinations = set() + + for k, s in context.current_state.items(): + try: + if k[0] == EventTypes.Member: + if s.content["membership"] == Membership.LEAVE: + destinations.add( + UserID.from_string(s.state_key).domain + ) + except: + logger.warn( + "Failed to get destination from event %s", s.event_id + ) + + destinations.discard(origin) + + logger.debug( + "on_send_leave_request: Sending event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + self.replication_layer.send_pdu(new_pdu, destinations) + + defer.returnValue(None) + @defer.inlineCallbacks def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): yield run_on_reactor() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3f0cde56f0..60f9fa58b0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -389,7 +389,22 @@ class RoomMemberHandler(BaseHandler): if event.membership == Membership.JOIN: yield self._do_join(event, context, do_auth=do_auth) else: - # This is not a JOIN, so we can handle it normally. + if event.membership == Membership.LEAVE: + is_host_in_room = yield self.is_host_in_room(room_id, context) + if not is_host_in_room: + # Rejecting an invite, rather than leaving a joined room + handler = self.hs.get_handlers().federation_handler + inviter = yield self.get_inviter(event) + if not inviter: + # return the same error as join_room_alias does + raise SynapseError(404, "No known servers") + yield handler.do_remotely_reject_invite( + [inviter.domain], + room_id, + event.user_id + ) + defer.returnValue({"room_id": room_id}) + return # FIXME: This isn't idempotency. if prev_state and prev_state.membership == event.membership: @@ -413,7 +428,7 @@ class RoomMemberHandler(BaseHandler): defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def join_room_alias(self, joinee, room_alias, do_auth=True, content={}): + def join_room_alias(self, joinee, room_alias, content={}): directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) @@ -447,8 +462,6 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def _do_join(self, event, context, room_hosts=None, do_auth=True): - joinee = UserID.from_string(event.state_key) - # room_id = RoomID.from_string(event.room_id, self.hs) room_id = event.room_id # XXX: We don't do an auth check if we are doing an invite @@ -456,48 +469,18 @@ class RoomMemberHandler(BaseHandler): # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - is_host_in_room = yield self.auth.check_host_in_room( - event.room_id, - self.hs.hostname - ) - if not is_host_in_room: - # is *anyone* in the room? - room_member_keys = [ - v for (k, v) in context.current_state.keys() if ( - k == "m.room.member" - ) - ] - if len(room_member_keys) == 0: - # has the room been created so we can join it? - create_event = context.current_state.get(("m.room.create", "")) - if create_event: - is_host_in_room = True - + is_host_in_room = yield self.is_host_in_room(room_id, context) if is_host_in_room: should_do_dance = False elif room_hosts: # TODO: Shouldn't this be remote_room_host? should_do_dance = True else: - # TODO(markjh): get prev_state from snapshot - prev_state = yield self.store.get_room_member( - joinee.to_string(), room_id - ) - - if prev_state and prev_state.membership == Membership.INVITE: - inviter = UserID.from_string(prev_state.user_id) - - should_do_dance = not self.hs.is_mine(inviter) - room_hosts = [inviter.domain] - elif "third_party_invite" in event.content: - if "sender" in event.content["third_party_invite"]: - inviter = UserID.from_string( - event.content["third_party_invite"]["sender"] - ) - should_do_dance = not self.hs.is_mine(inviter) - room_hosts = [inviter.domain] - else: + inviter = yield self.get_inviter(event) + if not inviter: # return the same error as join_room_alias does raise SynapseError(404, "No known servers") + should_do_dance = not self.hs.is_mine(inviter) + room_hosts = [inviter.domain] if should_do_dance: handler = self.hs.get_handlers().federation_handler @@ -505,8 +488,7 @@ class RoomMemberHandler(BaseHandler): room_hosts, room_id, event.user_id, - event.content, # FIXME To get a non-frozen dict - context + event.content # FIXME To get a non-frozen dict ) else: logger.debug("Doing normal join") @@ -523,6 +505,44 @@ class RoomMemberHandler(BaseHandler): "user_joined_room", user=user, room_id=room_id ) + @defer.inlineCallbacks + def get_inviter(self, event): + # TODO(markjh): get prev_state from snapshot + prev_state = yield self.store.get_room_member( + event.user_id, event.room_id + ) + + if prev_state and prev_state.membership == Membership.INVITE: + defer.returnValue(UserID.from_string(prev_state.user_id)) + return + elif "third_party_invite" in event.content: + if "sender" in event.content["third_party_invite"]: + inviter = UserID.from_string( + event.content["third_party_invite"]["sender"] + ) + defer.returnValue(inviter) + defer.returnValue(None) + + @defer.inlineCallbacks + def is_host_in_room(self, room_id, context): + is_host_in_room = yield self.auth.check_host_in_room( + room_id, + self.hs.hostname + ) + if not is_host_in_room: + # is *anyone* in the room? + room_member_keys = [ + v for (k, v) in context.current_state.keys() if ( + k == "m.room.member" + ) + ] + if len(room_member_keys) == 0: + # has the room been created so we can join it? + create_event = context.current_state.get(("m.room.create", "")) + if create_event: + is_host_in_room = True + defer.returnValue(is_host_in_room) + @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): """Returns a list of roomids that the user has any of the given diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a2123be81b..93896dd076 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -277,10 +277,10 @@ class RoomPermissionsTestCase(RestTestCase): expect_code=403) # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 403s + # expect all 404s because room doesn't exist on any server for usr in [self.user_id, self.rmcreator_id]: yield self.join(room=room, user=usr, expect_code=404) - yield self.leave(room=room, user=usr, expect_code=403) + yield self.leave(room=room, user=usr, expect_code=404) @defer.inlineCallbacks def test_membership_private_room_perms(self): -- cgit 1.5.1 From 45cd2b023399dc79a77cf59a356ed1c130d970d2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Oct 2015 15:33:25 +0100 Subject: Refactor api.filtering to have a Filter API --- synapse/api/filtering.py | 153 +++++++++++++---------------------- synapse/rest/client/v2_alpha/sync.py | 4 +- tests/api/test_filtering.py | 57 +++++++------ 3 files changed, 88 insertions(+), 126 deletions(-) (limited to 'tests') diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index e79e91e7eb..cd7a465e97 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -24,7 +24,7 @@ class Filtering(object): def get_user_filter(self, user_localpart, filter_id): result = self.store.get_user_filter(user_localpart, filter_id) - result.addCallback(Filter) + result.addCallback(FilterCollection) return result def add_user_filter(self, user_localpart, user_filter): @@ -131,125 +131,82 @@ class Filtering(object): raise SynapseError(400, "Bad bundle_updates: expected bool.") -class Filter(object): +class FilterCollection(object): def __init__(self, filter_json): self.filter_json = filter_json + self.room_timeline_filter = Filter( + self.filter_json.get("room", {}).get("timeline", {}) + ) + + self.room_state_filter = Filter( + self.filter_json.get("room", {}).get("state", {}) + ) + + self.room_ephemeral_filter = Filter( + self.filter_json.get("room", {}).get("ephemeral", {}) + ) + + self.presence_filter = Filter( + self.filter_json.get("presence", {}) + ) + def timeline_limit(self): - return self.filter_json.get("room", {}).get("timeline", {}).get("limit", 10) + return self.room_timeline_filter.limit() def presence_limit(self): - return self.filter_json.get("presence", {}).get("limit", 10) + return self.presence_filter.limit() def ephemeral_limit(self): - return self.filter_json.get("room", {}).get("ephemeral", {}).get("limit", 10) + return self.room_ephemeral_filter.limit() def filter_presence(self, events): - return self._filter_on_key(events, ["presence"]) + return self.presence_filter.filter(events) def filter_room_state(self, events): - return self._filter_on_key(events, ["room", "state"]) + return self.room_state_filter.filter(events) def filter_room_timeline(self, events): - return self._filter_on_key(events, ["room", "timeline"]) + return self.room_timeline_filter.filter(events) def filter_room_ephemeral(self, events): - return self._filter_on_key(events, ["room", "ephemeral"]) - - def _filter_on_key(self, events, keys): - filter_json = self.filter_json - if not filter_json: - return events - - try: - # extract the right definition from the filter - definition = filter_json - for key in keys: - definition = definition[key] - return self._filter_with_definition(events, definition) - except KeyError: - # return all events if definition isn't specified. - return events - - def _filter_with_definition(self, events, definition): - return [e for e in events if self._passes_definition(definition, e)] - - def _passes_definition(self, definition, event): - """Check if the event passes the filter definition - Args: - definition(dict): The filter definition to check against - event(dict or Event): The event to check - Returns: - True if the event passes the filter in the definition - """ - if type(event) is dict: - room_id = event.get("room_id") - sender = event.get("sender") - event_type = event["type"] - else: - room_id = getattr(event, "room_id", None) - sender = getattr(event, "sender", None) - event_type = event.type - return self._event_passes_definition( - definition, room_id, sender, event_type - ) + return self.room_ephemeral_filter.filter(events) - def _event_passes_definition(self, definition, room_id, sender, - event_type): - """Check if the event passes through the given definition. - Args: - definition(dict): The definition to check against. - room_id(str): The id of the room this event is in or None. - sender(str): The sender of the event - event_type(str): The type of the event. - Returns: - True if the event passes through the filter. - """ - # Algorithm notes: - # For each key in the definition, check the event meets the criteria: - # * For types: Literal match or prefix match (if ends with wildcard) - # * For senders/rooms: Literal match only - # * "not_" checks take presedence (e.g. if "m.*" is in both 'types' - # and 'not_types' then it is treated as only being in 'not_types') - - # room checks - if room_id is not None: - allow_rooms = definition.get("rooms", None) - reject_rooms = definition.get("not_rooms", None) - if reject_rooms and room_id in reject_rooms: - return False - if allow_rooms and room_id not in allow_rooms: - return False +class Filter(object): + def __init__(self, filter_json): + self.filter_json = filter_json - # sender checks - if sender is not None: - allow_senders = definition.get("senders", None) - reject_senders = definition.get("not_senders", None) - if reject_senders and sender in reject_senders: - return False - if allow_senders and sender not in allow_senders: + def check(self, event): + literal_keys = { + "rooms": lambda v: event.room_id == v, + "senders": lambda v: event.sender == v, + "types": lambda v: _matches_wildcard(event.type, v) + } + + for name, match_func in literal_keys.items(): + not_name = "not_%s" % (name,) + disallowed_values = self.filter_json.get(not_name, []) + if any(map(match_func, disallowed_values)): return False - # type checks - if "not_types" in definition: - for def_type in definition["not_types"]: - if self._event_matches_type(event_type, def_type): + allowed_values = self.filter_json.get(name, None) + if allowed_values is not None: + if not any(map(match_func, allowed_values)): return False - if "types" in definition: - included = False - for def_type in definition["types"]: - if self._event_matches_type(event_type, def_type): - included = True - break - if not included: - return False return True - def _event_matches_type(self, event_type, def_type): - if def_type.endswith("*"): - type_prefix = def_type[:-1] - return event_type.startswith(type_prefix) - else: - return event_type == def_type + def filter(self, events): + return filter(self.check, events) + + def limit(self): + return self.filter_json.get("limit", 10) + + +def _matches_wildcard(actual_value, filter_value): + if filter_value.endswith("*"): + type_prefix = filter_value[:-1] + return actual_value.startswith(type_prefix) + else: + return actual_value == filter_value diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index fffecb24f5..5e27a859f9 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -23,7 +23,7 @@ from synapse.types import StreamToken from synapse.events.utils import ( serialize_event, format_event_for_client_v2_without_event_id, ) -from synapse.api.filtering import Filter +from synapse.api.filtering import FilterCollection from ._base import client_v2_pattern import copy @@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet): user.localpart, filter_id ) except: - filter = Filter({}) + filter = FilterCollection({}) sync_config = SyncConfig( user=user, diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 6942cdac51..9f9af2d783 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -23,10 +23,17 @@ from tests.utils import ( ) from synapse.types import UserID -from synapse.api.filtering import Filter +from synapse.api.filtering import FilterCollection, Filter user_localpart = "test_user" -MockEvent = namedtuple("MockEvent", "sender type room_id") +# MockEvent = namedtuple("MockEvent", "sender type room_id") + + +def MockEvent(**kwargs): + ev = NonCallableMock(spec_set=kwargs.keys()) + ev.configure_mock(**kwargs) + return ev + class FilteringTestCase(unittest.TestCase): @@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase): ) self.filtering = hs.get_filtering() - self.filter = Filter({}) self.datastore = hs.get_datastore() @@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase): type="m.room.message", room_id="!foo:bar" ) + self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_types_works_with_wildcards(self): @@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_types_works_with_unknowns(self): @@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_works_with_literals(self): @@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_works_with_wildcards(self): @@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_works_with_unknowns(self): @@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_types_takes_priority_over_types(self): @@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_senders_works_with_literals(self): @@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_senders_works_with_unknowns(self): @@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_senders_works_with_literals(self): @@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_senders_works_with_unknowns(self): @@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_senders_takes_priority_over_senders(self): @@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!foo:bar" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_rooms_works_with_literals(self): @@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!secretbase:unknown" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_rooms_works_with_unknowns(self): @@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!anothersecretbase:unknown" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_rooms_works_with_literals(self): @@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!anothersecretbase:unknown" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_rooms_works_with_unknowns(self): @@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!anothersecretbase:unknown" ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_not_rooms_takes_priority_over_rooms(self): @@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!secretbase:unknown" ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event(self): @@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!stage:unknown" # yup ) self.assertTrue( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event_bad_sender(self): @@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!stage:unknown" # yup ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event_bad_room(self): @@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!piggyshouse:muppets" # nope ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) def test_definition_combined_event_bad_type(self): @@ -341,7 +348,7 @@ class FilteringTestCase(unittest.TestCase): room_id="!stage:unknown" # yup ) self.assertFalse( - self.filter._passes_definition(definition, event) + Filter(definition).check(event) ) @defer.inlineCallbacks @@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent( sender="@foo:bar", type="m.profile", - room_id="!foo:bar" ) events = [event] @@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent( sender="@foo:bar", type="custom.avatar.3d.crazy", - room_id="!foo:bar" ) events = [event] -- cgit 1.5.1 From f69a5c9134a3e4bba929dc76d561d9cc42cadeac Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 26 Oct 2015 18:32:49 +0000 Subject: Fix a 500 error resulting from empty room_ids POST /_matrix/client/api/v1/rooms//send/a.b.c gave a 500 error, because we assumed that rooms always had at least one character. --- synapse/types.py | 2 +- tests/test_types.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/types.py b/synapse/types.py index 9cffc33d27..8c51e00e8a 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -47,7 +47,7 @@ class DomainSpecificString( @classmethod def from_string(cls, s): """Parse the string given by 's' into a structure object.""" - if s[0] != cls.SIGIL: + if len(s) < 1 or s[0] != cls.SIGIL: raise SynapseError(400, "Expected %s string to start with '%s'" % ( cls.__name__, cls.SIGIL, )) diff --git a/tests/test_types.py b/tests/test_types.py index b29a8415b1..495cd20f02 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -15,13 +15,14 @@ from tests import unittest +from synapse.api.errors import SynapseError from synapse.server import BaseHomeServer from synapse.types import UserID, RoomAlias mock_homeserver = BaseHomeServer(hostname="my.domain") -class UserIDTestCase(unittest.TestCase): +class UserIDTestCase(unittest.TestCase): def test_parse(self): user = UserID.from_string("@1234abcd:my.domain") @@ -29,6 +30,11 @@ class UserIDTestCase(unittest.TestCase): self.assertEquals("my.domain", user.domain) self.assertEquals(True, mock_homeserver.is_mine(user)) + def test_pase_empty(self): + with self.assertRaises(SynapseError): + UserID.from_string("") + + def test_build(self): user = UserID("5678efgh", "my.domain") @@ -44,7 +50,6 @@ class UserIDTestCase(unittest.TestCase): class RoomAliasTestCase(unittest.TestCase): - def test_parse(self): room = RoomAlias.from_string("#channel:my.domain") -- cgit 1.5.1 From fb46937413cc0ccbf12063a5743ddf914cd8170a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 30 Oct 2015 16:38:35 +0000 Subject: Support clients supplying older tokens, fix short poll test --- synapse/types.py | 2 +- tests/rest/client/v1/test_presence.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/types.py b/synapse/types.py index 8d3a8d88cc..84631d177d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -112,7 +112,7 @@ class StreamToken( def from_string(cls, string): try: keys = string.split(cls._SEPARATOR) - if len(keys) == len(cls._fields) - 1: + while len(keys) < len(cls._fields): # i.e. old token from before receipt_key keys.append("0") return cls(*keys) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 29d9bbaad4..0e3b922246 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -369,7 +369,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): # all be ours # I'll already get my own presence state change - self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []}, + self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []}, response ) @@ -388,7 +388,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): "/events?from=s0_1_0&timeout=0", None) self.assertEquals(200, code) - self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [ + self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [ {"type": "m.presence", "content": { "user_id": "@banana:test", -- cgit 1.5.1 From 771ca56c886dd08f707447cfff70acd3ba73e98c Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Mon, 2 Nov 2015 15:31:57 +0000 Subject: Remove more unused parameters --- synapse/handlers/room.py | 1 - synapse/handlers/sync.py | 1 - synapse/storage/stream.py | 3 +-- tests/storage/test_redaction.py | 4 ---- tests/storage/test_stream.py | 4 ---- tests/utils.py | 2 +- 6 files changed, 2 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 36878a6c20..9184dcd048 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -827,7 +827,6 @@ class RoomEventSource(object): user_id=user.to_string(), from_key=from_key, to_key=to_key, - room_id=None, limit=limit, ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index eaa14f38df..4054efe555 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -342,7 +342,6 @@ class SyncHandler(BaseHandler): sync_config.user.to_string(), from_key=since_token.room_key, to_key=now_token.room_key, - room_id=None, limit=timeline_limit + 1, ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 15d4c2bf68..c728013f4c 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -158,8 +158,7 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @log_function - def get_room_events_stream(self, user_id, from_key, to_key, room_id, - limit=0): + def get_room_events_stream(self, user_id, from_key, to_key, limit=0): current_room_membership_sql = ( "SELECT m.room_id FROM room_memberships as m " " INNER JOIN current_state_events as c" diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index b57006fcb4..dbf9700e6a 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index a658a789aa..e5c2c5cc8e 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_alice.to_string(), start, end, - None, # Is currently ignored ) self.assertEqual(1, len(results)) @@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. @@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase): self.u_bob.to_string(), start, end, - None, # Is currently ignored ) # We should not get the message, as it happened *after* bob left. diff --git a/tests/utils.py b/tests/utils.py index 4da51291a4..ca2c33cf8e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,7 +335,7 @@ class MemoryDataStore(object): ] def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, - room_id=None, limit=0, with_feedback=False): + limit=0, with_feedback=False): return ([], from_key) # TODO def get_joined_hosts_for_room(self, room_id): -- cgit 1.5.1 From c452dabc3d295998ed70dfa977866568dce9fa79 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 4 Nov 2015 14:08:15 +0000 Subject: Remove the LockManager class because it wasn't being used --- synapse/handlers/federation.py | 2 - synapse/server.py | 5 -- synapse/util/lockutils.py | 74 ---------------------------- tests/util/test_lock.py | 108 ----------------------------------------- 4 files changed, 189 deletions(-) delete mode 100644 synapse/util/lockutils.py delete mode 100644 tests/util/test_lock.py (limited to 'tests') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ae9d227586..b2395b28d1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -72,8 +72,6 @@ class FederationHandler(BaseHandler): self.server_name = hs.hostname self.keyring = hs.get_keyring() - self.lock_manager = hs.get_room_lock_manager() - self.replication_layer.set_handler(self) # When joining a room we need to queue any events for that room up diff --git a/synapse/server.py b/synapse/server.py index 8424798b1b..f75d5358b2 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -29,7 +29,6 @@ from synapse.state import StateHandler from synapse.storage import DataStore from synapse.util import Clock from synapse.util.distributor import Distributor -from synapse.util.lockutils import LockManager from synapse.streams.events import EventSources from synapse.api.ratelimiting import Ratelimiter from synapse.crypto.keyring import Keyring @@ -70,7 +69,6 @@ class BaseHomeServer(object): 'auth', 'rest_servlet_factory', 'state_handler', - 'room_lock_manager', 'notifier', 'distributor', 'resource_for_client', @@ -201,9 +199,6 @@ class HomeServer(BaseHomeServer): def build_state_handler(self): return StateHandler(self) - def build_room_lock_manager(self): - return LockManager() - def build_distributor(self): return Distributor() diff --git a/synapse/util/lockutils.py b/synapse/util/lockutils.py deleted file mode 100644 index 33edc5c20e..0000000000 --- a/synapse/util/lockutils.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -import logging - - -logger = logging.getLogger(__name__) - - -class Lock(object): - - def __init__(self, deferred, key): - self._deferred = deferred - self.released = False - self.key = key - - def release(self): - self.released = True - self._deferred.callback(None) - - def __del__(self): - if not self.released: - logger.critical("Lock was destructed but never released!") - self.release() - - def __enter__(self): - return self - - def __exit__(self, type, value, traceback): - logger.debug("Releasing lock for key=%r", self.key) - self.release() - - -class LockManager(object): - """ Utility class that allows us to lock based on a `key` """ - - def __init__(self): - self._lock_deferreds = {} - - @defer.inlineCallbacks - def lock(self, key): - """ Allows us to block until it is our turn. - Args: - key (str) - Returns: - Lock - """ - new_deferred = defer.Deferred() - old_deferred = self._lock_deferreds.get(key) - self._lock_deferreds[key] = new_deferred - - if old_deferred: - logger.debug("Queueing on lock for key=%r", key) - yield old_deferred - logger.debug("Obtained lock for key=%r", key) - else: - logger.debug("Entering uncontended lock for key=%r", key) - - defer.returnValue(Lock(new_deferred, key)) diff --git a/tests/util/test_lock.py b/tests/util/test_lock.py deleted file mode 100644 index 6a1e521b1e..0000000000 --- a/tests/util/test_lock.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer -from tests import unittest - -from synapse.util.lockutils import LockManager - - -class LockManagerTestCase(unittest.TestCase): - - def setUp(self): - self.lock_manager = LockManager() - - @defer.inlineCallbacks - def test_one_lock(self): - key = "test" - deferred_lock1 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock1.called) - - lock1 = yield deferred_lock1 - - self.assertFalse(lock1.released) - - lock1.release() - - self.assertTrue(lock1.released) - - @defer.inlineCallbacks - def test_concurrent_locks(self): - key = "test" - deferred_lock1 = self.lock_manager.lock(key) - deferred_lock2 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock1.called) - self.assertFalse(deferred_lock2.called) - - lock1 = yield deferred_lock1 - - self.assertFalse(lock1.released) - self.assertFalse(deferred_lock2.called) - - lock1.release() - - self.assertTrue(lock1.released) - self.assertTrue(deferred_lock2.called) - - lock2 = yield deferred_lock2 - - lock2.release() - - @defer.inlineCallbacks - def test_sequential_locks(self): - key = "test" - deferred_lock1 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock1.called) - - lock1 = yield deferred_lock1 - - self.assertFalse(lock1.released) - - lock1.release() - - self.assertTrue(lock1.released) - - deferred_lock2 = self.lock_manager.lock(key) - - self.assertTrue(deferred_lock2.called) - - lock2 = yield deferred_lock2 - - self.assertFalse(lock2.released) - - lock2.release() - - self.assertTrue(lock2.released) - - @defer.inlineCallbacks - def test_with_statement(self): - key = "test" - with (yield self.lock_manager.lock(key)) as lock: - self.assertFalse(lock.released) - - self.assertTrue(lock.released) - - @defer.inlineCallbacks - def test_two_with_statement(self): - key = "test" - with (yield self.lock_manager.lock(key)): - pass - - with (yield self.lock_manager.lock(key)): - pass -- cgit 1.5.1 From f522f50a08d48042d103c98dbc3cfd4872b7d981 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 4 Nov 2015 17:29:07 +0000 Subject: Allow guests to register and call /events?room_id= This follows the same flows-based flow as regular registration, but as the only implemented flow has no requirements, it auto-succeeds. In the future, other flows (e.g. captcha) may be required, so clients should treat this like the regular registration flow choices. --- synapse/api/auth.py | 95 ++++++++++++++++------------- synapse/api/errors.py | 1 + synapse/config/registration.py | 6 ++ synapse/handlers/_base.py | 75 ++++++++++++++--------- synapse/handlers/auth.py | 5 +- synapse/handlers/message.py | 46 +++++++------- synapse/handlers/register.py | 12 ++-- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/directory.py | 4 +- synapse/rest/client/v1/events.py | 4 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/presence.py | 8 +-- synapse/rest/client/v1/profile.py | 4 +- synapse/rest/client/v1/push_rule.py | 6 +- synapse/rest/client/v1/pusher.py | 2 +- synapse/rest/client/v1/room.py | 27 ++++---- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/account.py | 6 +- synapse/rest/client/v2_alpha/filter.py | 4 +- synapse/rest/client/v2_alpha/keys.py | 6 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/register.py | 27 +++++++- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/client/v2_alpha/tags.py | 6 +- synapse/rest/media/v0/content_repository.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/storage/registration.py | 15 ++--- tests/api/test_auth.py | 25 +++++++- tests/rest/client/v1/test_presence.py | 10 +-- tests/rest/client/v1/test_profile.py | 4 +- tests/rest/client/v1/test_rooms.py | 21 ++++--- tests/rest/client/v1/test_typing.py | 3 +- tests/rest/client/v2_alpha/__init__.py | 3 +- 33 files changed, 272 insertions(+), 167 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 88445fe999..dfbbc5a1cd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -49,6 +49,7 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self._KNOWN_CAVEAT_PREFIXES = set([ "gen = ", + "guest = ", "type = ", "time < ", "user_id = ", @@ -183,15 +184,11 @@ class Auth(object): defer.returnValue(member) @defer.inlineCallbacks - def check_user_was_in_room(self, room_id, user_id, current_state=None): + def check_user_was_in_room(self, room_id, user_id): """Check if the user was in the room at some point. Args: room_id(str): The room to check. user_id(str): The user to check. - current_state(dict): Optional map of the current state of the room. - If provided then that map is used to check whether they are a - member of the room. Otherwise the current membership is - loaded from the database. Raises: AuthError if the user was never in the room. Returns: @@ -199,17 +196,11 @@ class Auth(object): room. This will be the join event if they are currently joined to the room. This will be the leave event if they have left the room. """ - if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) - else: - member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id - ) + member = yield self.state.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) membership = member.membership if member else None if membership not in (Membership.JOIN, Membership.LEAVE): @@ -497,7 +488,7 @@ class Auth(object): return default @defer.inlineCallbacks - def get_user_by_req(self, request): + def get_user_by_req(self, request, allow_guest=False): """ Get a registered user's ID. Args: @@ -535,7 +526,7 @@ class Auth(object): request.authenticated_entity = user_id - defer.returnValue((UserID.from_string(user_id), "")) + defer.returnValue((UserID.from_string(user_id), "", False)) return except KeyError: pass # normal users won't have the user_id query parameter set. @@ -543,6 +534,7 @@ class Auth(object): user_info = yield self._get_user_by_access_token(access_token) user = user_info["user"] token_id = user_info["token_id"] + is_guest = user_info["is_guest"] ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( @@ -557,9 +549,14 @@ class Auth(object): user_agent=user_agent ) + if is_guest and not allow_guest: + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + request.authenticated_entity = user.to_string() - defer.returnValue((user, token_id,)) + defer.returnValue((user, token_id, is_guest,)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", @@ -592,31 +589,45 @@ class Auth(object): self._validate_macaroon(macaroon) user_prefix = "user_id = " + user = None + guest = False for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) - # This codepath exists so that we can actually return a - # token ID, because we use token IDs in place of device - # identifiers throughout the codebase. - # TODO(daniel): Remove this fallback when device IDs are - # properly implemented. - ret = yield self._look_up_user_by_access_token(macaroon_str) - if ret["user"] != user: - logger.error( - "Macaroon user (%s) != DB user (%s)", - user, - ret["user"] - ) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, - "User mismatch in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) - defer.returnValue(ret) - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) + elif caveat.caveat_id == "guest = true": + guest = True + + if user is None: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + + if guest: + ret = { + "user": user, + "is_guest": True, + "token_id": None, + } + else: + # This codepath exists so that we can actually return a + # token ID, because we use token IDs in place of device + # identifiers throughout the codebase. + # TODO(daniel): Remove this fallback when device IDs are + # properly implemented. + ret = yield self._look_up_user_by_access_token(macaroon_str) + if ret["user"] != user: + logger.error( + "Macaroon user (%s) != DB user (%s)", + user, + ret["user"] + ) + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "User mismatch in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + defer.returnValue(ret) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", @@ -629,6 +640,7 @@ class Auth(object): v.satisfy_exact("type = access") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(self._verify_expiry) + v.satisfy_exact("guest = true") v.verify(macaroon, self.hs.config.macaroon_secret_key) v = pymacaroons.Verifier() @@ -666,6 +678,7 @@ class Auth(object): user_info = { "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), + "is_guest": False, } defer.returnValue(user_info) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b3fea27d0e..d4037b3d55 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -33,6 +33,7 @@ class Codes(object): NOT_FOUND = "M_NOT_FOUND" MISSING_TOKEN = "M_MISSING_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" + GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f5ef36a9f4..dca391f7af 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -34,6 +34,7 @@ class RegistrationConfig(Config): self.registration_shared_secret = config.get("registration_shared_secret") self.macaroon_secret_key = config.get("macaroon_secret_key") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) + self.allow_guest_access = config.get("allow_guest_access", False) def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -54,6 +55,11 @@ class RegistrationConfig(Config): # Larger numbers increase the work factor needed to generate the hash. # The default number of rounds is 12. bcrypt_rounds: 12 + + # Allows users to register as guests without a password/email/etc, and + # participate in rooms hosted on this server which have been made + # accessible to anonymous users. + allow_guest_access: False """ % locals() def add_arguments(self, parser): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6a26cb1879..6873a4575d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -47,37 +47,23 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events): - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - ) + def _filter_events_for_client(self, user_id, events, is_guest=False): + # Assumes that user has at some point joined the room if not is_guest. - def allowed(event, state): - if event.type == EventTypes.RoomHistoryVisibility: + def allowed(event, membership, visibility): + if visibility == "world_readable": return True - membership_ev = state.get((EventTypes.Member, user_id), None) - if membership_ev: - membership = membership_ev.membership - else: - membership = Membership.LEAVE + if is_guest: + return False if membership == Membership.JOIN: return True - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) - if history: - visibility = history.content.get("history_visibility", "shared") - else: - visibility = "shared" + if event.type == EventTypes.RoomHistoryVisibility: + return not is_guest - if visibility == "public": - return True - elif visibility == "shared": + if visibility == "shared": return True elif visibility == "joined": return membership == Membership.JOIN @@ -86,11 +72,44 @@ class BaseHandler(object): return True - defer.returnValue([ - event - for event in events - if allowed(event, event_id_to_state[event.event_id]) - ]) + event_id_to_state = yield self.store.get_state_for_events( + frozenset(e.event_id for e in events), + types=( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id), + ) + ) + + events_to_return = [] + for event in events: + state = event_id_to_state[event.event_id] + + membership_event = state.get((EventTypes.Member, user_id), None) + if membership_event: + membership = membership_event.membership + else: + membership = None + + visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + if visibility_event: + visibility = visibility_event.content.get("history_visibility", "shared") + else: + visibility = "shared" + + should_include = allowed(event, membership, visibility) + if should_include: + events_to_return.append(event) + + if is_guest and len(events_to_return) < len(events): + # This indicates that some events in the requested range were not + # visible to guest users. To be safe, we reject the entire request, + # so that we don't have to worry about interpreting visibility + # boundaries. + raise AuthError(403, "User %s does not have permission" % ( + user_id + )) + + defer.returnValue(events_to_return) def ratelimit(self, user_id): time_now = self.clock.time() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 055d395b20..1b11dbdffd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -372,12 +372,15 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id): + def generate_access_token(self, user_id, extra_caveats=None): + extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() expiry = now + (60 * 60 * 1000) macaroon.add_first_party_caveat("time < %d" % (expiry,)) + for caveat in extra_caveats: + macaroon.add_first_party_caveat(caveat) return macaroon.serialize() def generate_refresh_token(self, user_id): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f947993d1..687e1527f7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -71,20 +71,20 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, is_guest=False): """Get messages in a room. Args: user_id (str): The user requesting messages. room_id (str): The room they want messages from. pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. + config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + is_guest (bool): Whether the requesting user is a guest (as opposed + to a fully registered user). Returns: dict: Pagination API results """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: @@ -107,23 +107,27 @@ class MessageHandler(BaseHandler): source_config = pagin_config.get_source_config("room") - if member_event.membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room - leave_token = yield self.store.get_topological_token_for_event( - member_event.event_id - ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < room_token.topological: - source_config.from_key = str(leave_token) - - if source_config.direction == "f": - if source_config.to_key is None: - source_config.to_key = str(leave_token) - else: - to_token = RoomStreamToken.parse(source_config.to_key) - if leave_token.topological < to_token.topological: + if not is_guest: + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + if member_event.membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room. + # If they're a guest, we'll just 403 them if they're asking for + # events they can't see. + leave_token = yield self.store.get_topological_token_for_event( + member_event.event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < room_token.topological: + source_config.from_key = str(leave_token) + + if source_config.direction == "f": + if source_config.to_key is None: source_config.to_key = str(leave_token) + else: + to_token = RoomStreamToken.parse(source_config.to_key) + if leave_token.topological < to_token.topological: + source_config.to_key = str(leave_token) yield self.hs.get_handlers().federation_handler.maybe_backfill( room_id, room_token.topological @@ -146,7 +150,7 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - events = yield self._filter_events_for_client(user_id, events) + events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest) time_now = self.clock.time_msec() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ef4081e3fe..493a087031 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler): ) @defer.inlineCallbacks - def register(self, localpart=None, password=None): + def register(self, localpart=None, password=None, generate_token=True): """Registers a new client on the server. Args: @@ -89,7 +89,9 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_access_token(user_id) + token = None + if generate_token: + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -102,14 +104,14 @@ class RegistrationHandler(BaseHandler): attempts = 0 user_id = None token = None - while not user_id and not token: + while not user_id: try: localpart = self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - - token = self.auth_handler().generate_access_token(user_id) + if generate_token: + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 504b63eab4..bdde43864c 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4dcda57c1b..240eedac75 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet): try: # try to auth as a user - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) try: user_id = user.to_string() yield dir_handler.create_association( @@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet): # fallback to default user behaviour if they aren't an AS pass - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 582148b659..4073b0d2d1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 52c7943400..856a70f297 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) handler = self.handlers.message_handler diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a770efd841..6fe5d19a22 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index fdde88a60d..6b379e4e5f 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index bd759a2589..b0870db1ac 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet): except InvalidRuleException as e: raise SynapseError(400, e.message) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) if '/' in spec['rule_id'] or '\\' in spec['rule_id']: raise SynapseError(400, "rule_id may not contain slashes") @@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet): def on_DELETE(self, request): spec = _rule_spec_from_path(request.postpath) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) @@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 3aabc93b8b..a110c0a4f0 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2dcaee86cd..0876e593c5 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler events = yield handler.get_state_events( room_id=room_id, @@ -325,7 +325,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -334,6 +334,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): msgs = yield handler.get_messages( room_id=room_id, user_id=user.to_string(), + is_guest=is_guest, pagin_config=pagination_config, as_client_event=as_client_event ) @@ -347,7 +348,7 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -363,7 +364,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -443,7 +444,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -524,7 +525,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -564,7 +565,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) @@ -597,7 +598,7 @@ class SearchRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) content = _parse_json(request) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 0a863e1c61..eb7c57cade 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 4692ba413c..1970ad3458 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet): if LoginType.PASSWORD in result: # if using password, they should also be logged in - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if auth_user.to_string() != result[LoginType.PASSWORD]: raise LoginError(400, "", Codes.UNKNOWN) user_id = auth_user.to_string() @@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet): def on_GET(self, request): yield run_on_reactor() - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) threepids = yield self.hs.get_datastore().user_get_threepids( auth_user.to_string() @@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet): raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) threePidCreds = body['threePidCreds'] - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index f8f91b63f5..97956a4b91 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot get filters for other users") @@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if target_user != auth_user: raise AuthError(403, "Cannot create filters for other users") diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index a1f4423101..820d33336f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() # TODO: Check that the device_id matches that in the authentication # or derive the device_id from the authentication instead. @@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) user_id = auth_user.to_string() result = yield self.store.count_e2e_one_time_keys(user_id, device_id) @@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) auth_user_id = auth_user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index b107b7ce17..788acd4adb 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): - user, _ = yield self.auth.get_user_by_req(request) + user, _, _ = yield self.auth.get_user_by_req(request) if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1ba2f29711..f899376311 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError, Codes +from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet from ._base import client_v2_pattern, parse_json_dict_from_request @@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet): def on_POST(self, request): yield run_on_reactor() + kind = "user" + if "kind" in request.args: + kind = request.args["kind"][0] + + if kind == "guest": + ret = yield self._do_guest_registration() + defer.returnValue(ret) + return + elif kind != "user": + raise UnrecognizedRequestError( + "Do not understand membership kind: %s" % (kind,) + ) + if '/register/email/requestToken' in request.path: ret = yield self.onEmailTokenRequest(request) defer.returnValue(ret) @@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet): ret = yield self.identity_handler.requestEmailToken(**body) defer.returnValue((200, ret)) + @defer.inlineCallbacks + def _do_guest_registration(self): + if not self.hs.config.allow_guest_access: + defer.returnValue((403, "Guest access is disabled")) + user_id, _ = yield self.registration_handler.register(generate_token=False) + access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + defer.returnValue((200, { + "user_id": user_id, + "access_token": access_token, + "home_server": self.hs.hostname, + })) + def register_servlets(hs, http_server): RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 32a1087c91..d24507effa 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -81,7 +81,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user, token_id = yield self.auth.get_user_by_req(request) + user, token_id, _ = yield self.auth.get_user_by_req(request) timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index dcfe6bd20e..35482ae6a6 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -42,7 +42,7 @@ class TagListServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id, room_id): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot get tags for other users.") @@ -68,7 +68,7 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id, room_id, tag): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") @@ -88,7 +88,7 @@ class TagServlet(RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, user_id, room_id, tag): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index c28dc86cd7..e4fa8c4647 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 6abaf56b25..7d61596082 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource): @request_handler @defer.inlineCallbacks def _async_render_POST(self, request): - auth_user, _ = yield self.auth.get_user_by_req(request) + auth_user, _, _ = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index b454dd5b3a..2e5eddd259 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -102,13 +102,14 @@ class RegistrationStore(SQLBaseStore): 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - # it's possible for this to get a conflict, but only for a single user - # since tokens are namespaced based on their user ID - txn.execute( - "INSERT INTO access_tokens(id, user_id, token)" - " VALUES (?,?,?)", - (next_id, user_id, token,) - ) + if token: + # it's possible for this to get a conflict, but only for a single user + # since tokens are namespaced based on their user ID + txn.execute( + "INSERT INTO access_tokens(id, user_id, token)" + " VALUES (?,?,?)", + (next_id, user_id, token,) + ) def get_user_by_id(self, user_id): return self._simple_select_one( diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index c96273480d..70d928defe 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), self.test_user) def test_get_user_by_req_appservice_bad_token(self): @@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase): request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = Mock(return_value=[""]) - (user, _) = yield self.auth.get_user_by_req(request) + (user, _, _) = yield self.auth.get_user_by_req(request) self.assertEquals(user.to_string(), masquerading_user_id) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): @@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase): user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + @defer.inlineCallbacks + def test_get_guest_user_from_macaroon(self): + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("guest = true") + serialized = macaroon.serialize() + + user_info = yield self.auth._get_user_from_macaroon(serialized) + user = user_info["user"] + is_guest = user_info["is_guest"] + self.assertEqual(UserID.from_string(user_id), user) + self.assertTrue(is_guest) + @defer.inlineCallbacks def test_get_user_from_macaroon_user_db_mismatch(self): self.store.get_user_by_access_token = Mock( diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0e3b922246..3e0f294630 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -86,10 +86,11 @@ class PresenceStateTestCase(unittest.TestCase): return defer.succeed([]) self.datastore.get_presence_list = get_presence_list - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(myid), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -173,10 +174,11 @@ class PresenceListTestCase(unittest.TestCase): ) self.datastore.has_presence_state = has_presence_state - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(myid), "token_id": 1, + "is_guest": False, } hs.handlers.room_member_handler = Mock( @@ -291,8 +293,8 @@ class PresenceEventStreamTestCase(unittest.TestCase): hs.get_clock().time_msec.return_value = 1000000 - def _get_user_by_req(req=None): - return (UserID.from_string(myid), "") + def _get_user_by_req(req=None, allow_guest=False): + return (UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 929e5e5dd4..adcc1d1969 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase): replication_layer=Mock(), ) - def _get_user_by_req(request=None): - return (UserID.from_string(myid), "") + def _get_user_by_req(request=None, allow_guest=False): + return (UserID.from_string(myid), "", False) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 93896dd076..b43563fa4b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase): self.auth_user_id = self.user_id - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token @@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 6395ce79db..8433585616 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, + "is_guest": False, } hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index f45570a1c0..fa9e17ec4f 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase): resource_for_federation=self.mock_resource, ) - def _get_user_by_access_token(token=None): + def _get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.USER_ID), "token_id": 1, + "is_guest": False, } hs.get_auth()._get_user_by_access_token = _get_user_by_access_token -- cgit 1.5.1 From ca2f90742d5606f8fc5b7ddd3dd7244c377c1df8 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 5 Nov 2015 14:32:26 +0000 Subject: Open up /events to anonymous users for room events only Squash-merge of PR #345 from daniel/anonymousevents --- synapse/handlers/_base.py | 7 ++- synapse/handlers/events.py | 10 ++- synapse/handlers/message.py | 47 ++++++++++---- synapse/handlers/presence.py | 4 +- synapse/handlers/private_user_data.py | 2 +- synapse/handlers/receipts.py | 6 +- synapse/handlers/room.py | 11 +++- synapse/handlers/sync.py | 20 +++++- synapse/handlers/typing.py | 11 +--- synapse/notifier.py | 42 ++++++++++--- synapse/rest/client/v1/events.py | 13 +++- synapse/rest/client/v1/room.py | 6 +- synapse/storage/events.py | 2 + synapse/storage/room.py | 13 ++++ .../storage/schema/delta/25/history_visibility.sql | 26 ++++++++ synapse/storage/stream.py | 46 +++++++++++--- tests/handlers/test_presence.py | 71 ++++++++++++++++------ tests/handlers/test_typing.py | 30 +++++++-- tests/rest/client/v1/test_presence.py | 9 ++- tests/rest/client/v1/test_typing.py | 5 +- 20 files changed, 299 insertions(+), 82 deletions(-) create mode 100644 synapse/storage/schema/delta/25/history_visibility.sql (limited to 'tests') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6873a4575d..a9e43052b7 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -47,7 +47,8 @@ class BaseHandler(object): self.event_builder_factory = hs.get_event_builder_factory() @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events, is_guest=False): + def _filter_events_for_client(self, user_id, events, is_guest=False, + require_all_visible_for_guests=True): # Assumes that user has at some point joined the room if not is_guest. def allowed(event, membership, visibility): @@ -100,7 +101,9 @@ class BaseHandler(object): if should_include: events_to_return.append(event) - if is_guest and len(events_to_return) < len(events): + if (require_all_visible_for_guests + and is_guest + and len(events_to_return) < len(events)): # This indicates that some events in the requested range were not # visible to guest users. To be safe, we reject the entire request, # so that we don't have to worry about interpreting visibility diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 53c8ca3a26..0e4c0d4d06 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler): @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, as_client_event=True, affect_presence=True, - only_room_events=False): + only_room_events=False, room_id=None, is_guest=False): """Fetches the events stream for a given user. If `only_room_events` is `True` only room events will be returned. @@ -119,9 +119,15 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) + if is_guest: + yield self.distributor.fire( + "user_joined_room", user=auth_user, room_id=room_id + ) + events, tokens = yield self.notifier.get_events_for( auth_user, pagin_config, timeout, - only_room_events=only_room_events + only_room_events=only_room_events, + is_guest=is_guest, guest_room_id=room_id ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 687e1527f7..654ecd2b37 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, AuthError, Codes from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -229,7 +229,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key=""): + event_type=None, state_key="", is_guest=False): """ Get data from a room. Args: @@ -239,23 +239,42 @@ class MessageHandler(BaseHandler): Raises: SynapseError if something went wrong. """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, is_guest + ) - if member_event.membership == Membership.JOIN: + if membership == Membership.JOIN: data = yield self.state_handler.get_current_state( room_id, event_type, state_key ) - elif member_event.membership == Membership.LEAVE: + elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( - [member_event.event_id], [key] + [membership_event_id], [key] ) - data = room_state[member_event.event_id].get(key) + data = room_state[membership_event_id].get(key) defer.returnValue(data) @defer.inlineCallbacks - def get_state_events(self, user_id, room_id): + def _check_in_room_or_world_readable(self, room_id, user_id, is_guest): + if is_guest: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if visibility.content["history_visibility"] == "world_readable": + defer.returnValue((Membership.JOIN, None)) + return + else: + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) + else: + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + + @defer.inlineCallbacks + def get_state_events(self, user_id, room_id, is_guest=False): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has left the room return the state events from when they left. @@ -266,15 +285,17 @@ class MessageHandler(BaseHandler): Returns: A list of dicts representing state events. [{}, {}, {}] """ - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + membership, membership_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, is_guest + ) - if member_event.membership == Membership.JOIN: + if membership == Membership.JOIN: room_state = yield self.state_handler.get_current_state(room_id) - elif member_event.membership == Membership.LEAVE: + elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [member_event.event_id], None + [membership_event_id], None ) - room_state = room_state[member_event.event_id] + room_state = room_state[membership_event_id] now = self.clock.time_msec() defer.returnValue( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index ce60642127..0b780cd528 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1142,8 +1142,9 @@ class PresenceEventSource(object): @defer.inlineCallbacks @log_function - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, user, from_key, room_ids=None, **kwargs): from_key = int(from_key) + room_ids = room_ids or [] presence = self.hs.get_handlers().presence_handler cachemap = presence._user_cachemap @@ -1161,7 +1162,6 @@ class PresenceEventSource(object): user_ids_to_check |= set( UserID.from_string(p["observed_user_id"]) for p in presence_list ) - room_ids = yield presence.get_joined_rooms_for_user(user) for room_id in set(room_ids) & set(presence._room_serials): if presence._room_serials[room_id] > from_key: joined = yield presence.get_joined_users_for_room_id(room_id) diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/private_user_data.py index 1778c71325..1abe45ed7b 100644 --- a/synapse/handlers/private_user_data.py +++ b/synapse/handlers/private_user_data.py @@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object): return self.store.get_max_private_user_data_stream_id() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a47ae3df42..973f4d5cae 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -164,17 +164,15 @@ class ReceiptEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) to_key = yield self.get_current_key() if from_key == to_key: defer.returnValue(([], to_key)) - rooms = yield self.store.get_rooms_for_user(user.to_string()) - rooms = [room.room_id for room in rooms] events = yield self.store.get_linearized_receipts_for_rooms( - rooms, + room_ids, from_key=from_key, to_key=to_key, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9184dcd048..736ffe9066 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -807,7 +807,14 @@ class RoomEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events( + self, + user, + from_key, + limit, + room_ids, + is_guest, + ): # We just ignore the key for now. to_key = yield self.get_current_key() @@ -828,6 +835,8 @@ class RoomEventSource(object): from_key=from_key, to_key=to_key, limit=limit, + room_ids=room_ids, + is_guest=is_guest, ) defer.returnValue((events, end_key)) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1c1ee34b1e..5294d96466 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -295,11 +295,16 @@ class SyncHandler(BaseHandler): typing_key = since_token.typing_key if since_token else "0" + rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) + room_ids = [room.room_id for room in rooms] + typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events_for_user( + typing, typing_key = yield typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter.ephemeral_limit(), + room_ids=room_ids, + is_guest=False, ) now_token = now_token.copy_and_replace("typing_key", typing_key) @@ -312,10 +317,13 @@ class SyncHandler(BaseHandler): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events_for_user( + receipts, receipt_key = yield receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter.ephemeral_limit(), + room_ids=room_ids, + # /sync doesn't support guest access, they can't get to this point in code + is_guest=False, ) now_token = now_token.copy_and_replace("receipt_key", receipt_key) @@ -360,11 +368,17 @@ class SyncHandler(BaseHandler): """ now_token = yield self.event_sources.get_current_token() + rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) + room_ids = [room.room_id for room in rooms] + presence_source = self.event_sources.sources["presence"] - presence, presence_key = yield presence_source.get_new_events_for_user( + presence, presence_key = yield presence_source.get_new_events( user=sync_config.user, from_key=since_token.presence_key, limit=sync_config.filter.presence_limit(), + room_ids=room_ids, + # /sync doesn't support guest access, they can't get to this point in code + is_guest=False, ) now_token = now_token.copy_and_replace("presence_key", presence_key) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d7096aab8c..2846f3e6e8 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -246,17 +246,12 @@ class TypingNotificationEventSource(object): }, } - @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events(self, from_key, room_ids, **kwargs): from_key = int(from_key) handler = self.handler() - joined_room_ids = ( - yield self.room_member_handler().get_joined_rooms_for_user(user) - ) - events = [] - for room_id in joined_room_ids: + for room_id in room_ids: if room_id not in handler._room_serials: continue if handler._room_serials[room_id] <= from_key: @@ -264,7 +259,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - defer.returnValue((events, handler._latest_room_serial)) + return events, handler._latest_room_serial def get_current_key(self): return self.handler()._latest_room_serial diff --git a/synapse/notifier.py b/synapse/notifier.py index b69da63d43..56c4c863b5 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -269,7 +269,7 @@ class Notifier(object): logger.exception("Failed to notify listener") @defer.inlineCallbacks - def wait_for_events(self, user, timeout, callback, + def wait_for_events(self, user, timeout, callback, room_ids=None, from_token=StreamToken("s0", "0", "0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. @@ -279,11 +279,12 @@ class Notifier(object): if user_stream is None: appservice = yield self.store.get_app_service_by_user_id(user) current_token = yield self.event_sources.get_current_token() - rooms = yield self.store.get_rooms_for_user(user) - rooms = [room.room_id for room in rooms] + if room_ids is None: + rooms = yield self.store.get_rooms_for_user(user) + room_ids = [room.room_id for room in rooms] user_stream = _NotifierUserStream( user=user, - rooms=rooms, + rooms=room_ids, appservice=appservice, current_token=current_token, time_now_ms=self.clock.time_msec(), @@ -329,7 +330,8 @@ class Notifier(object): @defer.inlineCallbacks def get_events_for(self, user, pagination_config, timeout, - only_room_events=False): + only_room_events=False, + is_guest=False, guest_room_id=None): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -342,6 +344,16 @@ class Notifier(object): limit = pagination_config.limit + room_ids = [] + if is_guest: + # TODO(daniel): Deal with non-room events too + only_room_events = True + if guest_room_id: + room_ids = [guest_room_id] + else: + rooms = yield self.store.get_rooms_for_user(user.to_string()) + room_ids = [room.room_id for room in rooms] + @defer.inlineCallbacks def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): @@ -357,9 +369,23 @@ class Notifier(object): continue if only_room_events and name != "room": continue - new_events, new_key = yield source.get_new_events_for_user( - user, getattr(from_token, keyname), limit, + new_events, new_key = yield source.get_new_events( + user=user, + from_key=getattr(from_token, keyname), + limit=limit, + is_guest=is_guest, + room_ids=room_ids, ) + + if is_guest: + room_member_handler = self.hs.get_handlers().room_member_handler + new_events = yield room_member_handler._filter_events_for_client( + user.to_string(), + new_events, + is_guest=is_guest, + require_all_visible_for_guests=False + ) + events.extend(new_events) end_token = end_token.copy_and_replace(keyname, new_key) @@ -369,7 +395,7 @@ class Notifier(object): defer.returnValue(None) result = yield self.wait_for_events( - user, timeout, check_for_updates, from_token=from_token + user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token ) if result is None: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 4073b0d2d1..3e1750d1a1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user, _, _ = yield self.auth.get_user_by_req(request) + auth_user, _, is_guest = yield self.auth.get_user_by_req( + request, + allow_guest=True + ) + room_id = None + if is_guest: + if "room_id" not in request.args: + raise SynapseError(400, "Guest users must specify room_id param") + room_id = request.args["room_id"][0] try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet): chunk = yield handler.get_stream( auth_user.to_string(), pagin_config, timeout=timeout, - as_client_event=as_client_event + as_client_event=as_client_event, affect_presence=(not is_guest), + room_id=room_id, is_guest=is_guest ) except: logger.exception("Event stream failed") diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 0876e593c5..afb802baec 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user, _, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): room_id=room_id, event_type=event_type, state_key=state_key, + is_guest=is_guest, ) if not data: @@ -348,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user, _, _ = yield self.auth.get_user_by_req(request) + user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( room_id=room_id, user_id=user.to_string(), + is_guest=is_guest, ) defer.returnValue((200, events)) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index e6c1abfc27..59c9987202 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore): self._store_room_message_txn(txn, event) elif event.type == EventTypes.Redaction: self._store_redaction(txn, event) + elif event.type == EventTypes.RoomHistoryVisibility: + self._store_history_visibility_txn(txn, event) self._store_room_members_txn( txn, diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 13441fcdce..1c79626736 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore): txn, event, "content.body", event.content["body"] ) + def _store_history_visibility_txn(self, txn, event): + if hasattr(event, "content") and "history_visibility" in event.content: + sql = ( + "INSERT INTO history_visibility" + " (event_id, room_id, history_visibility)" + " VALUES (?, ?, ?)" + ) + txn.execute(sql, ( + event.event_id, + event.room_id, + event.content["history_visibility"] + )) + def _store_event_search_txn(self, txn, event, key, value): if isinstance(self.database_engine, PostgresEngine): sql = ( diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/schema/delta/25/history_visibility.sql new file mode 100644 index 0000000000..9f387ed69f --- /dev/null +++ b/synapse/storage/schema/delta/25/history_visibility.sql @@ -0,0 +1,26 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of history_visibility content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS history_visibility( + id INTEGER PRIMARY KEY, + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + history_visibility TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c728013f4c..be8ba76aae 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -158,13 +158,40 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @log_function - def get_room_events_stream(self, user_id, from_key, to_key, limit=0): - current_room_membership_sql = ( - "SELECT m.room_id FROM room_memberships as m " - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id AND c.state_key = m.user_id" - " WHERE m.user_id = ? AND m.membership = 'join'" - ) + def get_room_events_stream( + self, + user_id, + from_key, + to_key, + limit=0, + is_guest=False, + room_ids=None + ): + room_ids = room_ids or [] + room_ids = [r for r in room_ids] + if is_guest: + current_room_membership_sql = ( + "SELECT c.room_id FROM history_visibility AS h" + " INNER JOIN current_state_events AS c" + " ON h.event_id = c.event_id" + " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % ( + ",".join(map(lambda _: "?", room_ids)) + ) + ) + current_room_membership_args = room_ids + else: + current_room_membership_sql = ( + "SELECT m.room_id FROM room_memberships as m " + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id AND c.state_key = m.user_id" + " WHERE m.user_id = ? AND m.membership = 'join'" + ) + current_room_membership_args = [user_id] + if room_ids: + current_room_membership_sql += " AND m.room_id in (%s)" % ( + ",".join(map(lambda _: "?", room_ids)) + ) + current_room_membership_args = [user_id] + room_ids # We also want to get any membership events about that user, e.g. # invites or leave notifications. @@ -173,6 +200,7 @@ class StreamStore(SQLBaseStore): "INNER JOIN current_state_events as c ON m.event_id = c.event_id " "WHERE m.user_id = ? " ) + membership_args = [user_id] if limit: limit = max(limit, MAX_STREAM_SIZE) @@ -199,7 +227,9 @@ class StreamStore(SQLBaseStore): } def f(txn): - txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) + args = ([False] + current_room_membership_args + membership_args + + [from_id.stream, to_id.stream]) + txn.execute(sql, args) rows = self.cursor_to_dict(txn) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 29372d488a..10d4482cce 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): {"presence": ONLINE} ) + # Apple sees self-reflection even without room_id + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + ) + + self.assertEquals(self.event_source.get_current_key(), 1) + self.assertEquals(events, + [ + {"type": "m.presence", + "content": { + "user_id": "@apple:test", + "presence": ONLINE, + "last_active_ago": 0, + }}, + ], + msg="Presence event should be visible to self-reflection" + ) + # Apple sees self-reflection - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Banana sees it because of presence subscription - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_banana, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_banana, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Elderberry sees it because of same room - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_elderberry, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_elderberry, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) # Durian is not in the room, should not see this event - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_durian, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_durian, + from_key=0, + room_ids=[], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): "accepted": True}, ], presence) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 1, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=1, ) self.assertEquals(self.event_source.get_current_key(), 2) @@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): ) ) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id], ) self.assertEquals(self.event_source.get_current_key(), 1) @@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.assertEquals(self.event_source.get_current_key(), 1) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id,] ) self.assertEquals(events, [ @@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.assertEquals(self.event_source.get_current_key(), 2) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, + room_ids=[self.room_id,] ) self.assertEquals(events, [ @@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase): self.room_members.append(self.u_clementine) - (events, _) = yield self.event_source.get_new_events_for_user( - self.u_apple, 0, None + (events, _) = yield self.event_source.get_new_events( + user=self.u_apple, + from_key=0, ) self.assertEquals(self.event_source.get_current_key(), 1) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 41bb08b7ca..2d7ba43561 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0 + ) self.assertEquals( events[0], [ @@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase): yield put_json.await_calls() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ @@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase): ]) self.assertEquals(self.event_source.get_current_key(), 2) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=1, + ) self.assertEquals( events[0], [ @@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) + events = yield self.event_source.get_new_events( + room_ids=[self.room_id], + from_key=0, + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 3e0f294630..7f29d73d95 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -47,7 +47,14 @@ class NullSource(object): def __init__(self, hs): pass - def get_new_events_for_user(self, user, from_key, limit): + def get_new_events( + self, + user, + from_key, + room_ids=None, + limit=None, + is_guest=None + ): return defer.succeed(([], from_key)) def get_current_key(self, direction='f'): diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 8433585616..61b9cc743b 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -116,7 +116,10 @@ class RoomTypingTestCase(RestTestCase): self.assertEquals(200, code) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events_for_user(self.user, 0, None) + events = yield self.event_source.get_new_events( + from_key=0, + room_ids=[self.room_id], + ) self.assertEquals( events[0], [ -- cgit 1.5.1 From 9107ed23b73b76347a63a2a2eea4e41f30f02062 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 5 Nov 2015 16:56:40 +0000 Subject: Add a couple of unit tests for room//messages ... merely because I was trying to figure out how it worked, and couldn't. --- synapse/rest/client/v1/room.py | 2 +- tests/rest/client/v1/test_rooms.py | 56 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6e0d93766b..f7012067f7 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -319,7 +319,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): })) -# TODO: Needs unit testing +# TODO: Needs better unit testing class RoomMessageListRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/rooms/(?P[^/]*)/messages$") diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index b43563fa4b..7749378064 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -994,3 +994,59 @@ class RoomInitialSyncTestCase(RestTestCase): } self.assertTrue(self.user_id in presence_by_user) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) + + +class RoomMessageListTestCase(RestTestCase): + """ Tests /rooms/$room_id/messages REST events. """ + user_id = "@sid1:red" + + @defer.inlineCallbacks + def setUp(self): + self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) + self.auth_user_id = self.user_id + + hs = yield setup_test_homeserver( + "red", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=["send_message"]), + ) + self.ratelimiter = hs.get_ratelimiter() + self.ratelimiter.send_message.return_value = (True, 0) + + hs.get_handlers().federation_handler = Mock() + + def _get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.auth_user_id), + "token_id": 1, + "is_guest": False, + } + hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token + + def _insert_client_ip(*args, **kwargs): + return defer.succeed(None) + hs.get_datastore().insert_client_ip = _insert_client_ip + + synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) + + self.room_id = yield self.create_room_as(self.user_id) + + @defer.inlineCallbacks + def test_topo_token_is_accepted(self): + token = "t1-0_0_0_0_0" + (code, response) = yield self.mock_resource.trigger_get( + "/rooms/%s/messages?access_token=x&from=%s" % + (self.room_id, token)) + self.assertEquals(200, code) + self.assertTrue("start" in response) + self.assertEquals(token, response['start']) + self.assertTrue("chunk" in response) + self.assertTrue("end" in response) + + @defer.inlineCallbacks + def test_stream_token_is_rejected(self): + (code, response) = yield self.mock_resource.trigger_get( + "/rooms/%s/messages?access_token=x&from=s0_0_0_0" % + self.room_id) + self.assertEquals(400, code) -- cgit 1.5.1 From 36c58b18a32f05a2f025bc916c14b9e2f78f439b Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 10 Nov 2015 15:51:40 +0000 Subject: Test for background updates --- tests/storage/test_background_update.py | 76 +++++++++++++++++++++++++++++++++ tests/utils.py | 3 ++ 2 files changed, 79 insertions(+) create mode 100644 tests/storage/test_background_update.py (limited to 'tests') diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py new file mode 100644 index 0000000000..29289fa9b4 --- /dev/null +++ b/tests/storage/test_background_update.py @@ -0,0 +1,76 @@ +from tests import unittest +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.types import UserID, RoomID, RoomAlias + +from tests.utils import setup_test_homeserver + +from mock import Mock + +class BackgroundUpdateTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + self.update_handler = Mock() + + yield self.store.register_background_update_handler( + "test_update", self.update_handler + ) + + @defer.inlineCallbacks + def test_do_background_update(self): + desired_count = 1000; + duration_ms = 42; + + @defer.inlineCallbacks + def update(progress, count): + self.clock.advance_time_msec(count * duration_ms) + progress = {"my_key": progress["my_key"] + 1} + yield self.store.runInteraction( + "update_progress", + self.store._background_update_progress_txn, + "test_update", + progress, + ) + defer.returnValue(count) + + self.update_handler.side_effect = update + + yield self.store.start_background_update("test_update", {"my_key": 1}) + + self.update_handler.reset_mock() + result = yield self.store.do_background_update( + duration_ms * desired_count + ) + self.assertIsNotNone(result) + self.update_handler.assert_called_once_with( + {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + ) + + @defer.inlineCallbacks + def update(progress, count): + yield self.store._end_background_update("test_update") + defer.returnValue(count) + + self.update_handler.side_effect = update + + self.update_handler.reset_mock() + result = yield self.store.do_background_update( + duration_ms * desired_count + ) + self.assertIsNotNone(result) + self.update_handler.assert_called_once_with( + {"my_key": 2}, desired_count + ) + + self.update_handler.reset_mock() + result = yield self.store.do_background_update( + duration_ms * desired_count + ) + self.assertIsNone(result) + self.assertFalse(self.update_handler.called) diff --git a/tests/utils.py b/tests/utils.py index ca2c33cf8e..91040c2efd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -243,6 +243,9 @@ class MockClock(object): else: self.timers.append(t) + def advance_time_msec(self, ms): + self.advance_time(ms / 1000.) + class SQLiteMemoryDbPool(ConnectionPool, object): def __init__(self): -- cgit 1.5.1 From cf437900e0c689aad40f3da89866cf84c0f7ef65 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 10 Nov 2015 17:10:27 +0000 Subject: Return world_readable and guest_can_join in /publicRooms --- synapse/storage/events.py | 2 + synapse/storage/room.py | 71 ++++++++++++++---------- synapse/storage/schema/delta/25/guest_access.sql | 25 +++++++++ tests/storage/test_room.py | 2 + 4 files changed, 71 insertions(+), 29 deletions(-) create mode 100644 synapse/storage/schema/delta/25/guest_access.sql (limited to 'tests') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 59c9987202..4a365ff639 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -313,6 +313,8 @@ class EventsStore(SQLBaseStore): self._store_redaction(txn, event) elif event.type == EventTypes.RoomHistoryVisibility: self._store_history_visibility_txn(txn, event) + elif event.type == EventTypes.GuestAccess: + self._store_guest_access_txn(txn, event) self._store_room_members_txn( txn, diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 1c79626736..4f08df478c 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -99,34 +99,39 @@ class RoomStore(SQLBaseStore): """ def f(txn): - topic_subquery = ( - "SELECT topics.event_id as event_id, " - "topics.room_id as room_id, topic " - "FROM topics " - "INNER JOIN current_state_events as c " - "ON c.event_id = topics.event_id " - ) - - name_subquery = ( - "SELECT room_names.event_id as event_id, " - "room_names.room_id as room_id, name " - "FROM room_names " - "INNER JOIN current_state_events as c " - "ON c.event_id = room_names.event_id " - ) + def subquery(table_name, column_name=None): + column_name = column_name or table_name + return ( + "SELECT %(table_name)s.event_id as event_id, " + "%(table_name)s.room_id as room_id, %(column_name)s " + "FROM %(table_name)s " + "INNER JOIN current_state_events as c " + "ON c.event_id = %(table_name)s.event_id " % { + "column_name": column_name, + "table_name": table_name, + } + ) - # We use non printing ascii character US (\x1F) as a separator sql = ( - "SELECT r.room_id, max(n.name), max(t.topic)" + "SELECT" + " r.room_id," + " max(n.name)," + " max(t.topic)," + " max(v.history_visibility)," + " max(g.guest_access)" " FROM rooms AS r" " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id" " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id" + " LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id" + " LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id" " WHERE r.is_public = ?" - " GROUP BY r.room_id" - ) % { - "topic": topic_subquery, - "name": name_subquery, - } + " GROUP BY r.room_id" % { + "topic": subquery("topics", "topic"), + "name": subquery("room_names", "name"), + "history_visibility": subquery("history_visibility"), + "guest_access": subquery("guest_access"), + } + ) txn.execute(sql, (is_public,)) @@ -156,10 +161,12 @@ class RoomStore(SQLBaseStore): "room_id": r[0], "name": r[1], "topic": r[2], - "aliases": r[3], + "world_readable": r[3] == "world_readable", + "guest_can_join": r[4] == "can_join", + "aliases": r[5], } for r in rows - if r[3] # We only return rooms that have at least one alias. + if r[5] # We only return rooms that have at least one alias. ] defer.returnValue(ret) @@ -203,16 +210,22 @@ class RoomStore(SQLBaseStore): ) def _store_history_visibility_txn(self, txn, event): - if hasattr(event, "content") and "history_visibility" in event.content: + self._store_content_index_txn(txn, event, "history_visibility") + + def _store_guest_access_txn(self, txn, event): + self._store_content_index_txn(txn, event, "guest_access") + + def _store_content_index_txn(self, txn, event, key): + if hasattr(event, "content") and key in event.content: sql = ( - "INSERT INTO history_visibility" - " (event_id, room_id, history_visibility)" - " VALUES (?, ?, ?)" + "INSERT INTO %(key)s" + " (event_id, room_id, %(key)s)" + " VALUES (?, ?, ?)" % {"key": key} ) txn.execute(sql, ( event.event_id, event.room_id, - event.content["history_visibility"] + event.content[key] )) def _store_event_search_txn(self, txn, event, key, value): diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/schema/delta/25/guest_access.sql new file mode 100644 index 0000000000..bdb90e7118 --- /dev/null +++ b/synapse/storage/schema/delta/25/guest_access.sql @@ -0,0 +1,25 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This is a manual index of guest_access content of state events, + * so that we can join on them in SELECT statements. + */ +CREATE TABLE IF NOT EXISTS guest_access( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + guest_access TEXT NOT NULL, + UNIQUE (event_id) +); diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index caffce64e3..91c967548d 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -73,6 +73,8 @@ class RoomStoreTestCase(unittest.TestCase): "room_id": self.room.to_string(), "topic": None, "aliases": [self.alias.to_string()], + "world_readable": False, + "guest_can_join": False, }, rooms[0]) -- cgit 1.5.1 From 78f6010207d5e6908ba584121461af4b02714287 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 12 Nov 2015 13:10:25 +0000 Subject: Fix an issue with ignoring power_level changes on divergent graphs Changes to m.room.power_levels events are supposed to be handled at a high priority; however a typo meant that the relevant bit of code was never executed, so they were handled just like any other state change - which meant that a bad person could cause room state changes by forking the graph from a point in history when they were allowed to do so. --- synapse/state.py | 16 ++++++--- tests/test_state.py | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/synapse/state.py b/synapse/state.py index bb225c39cf..f893df3378 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -307,19 +307,23 @@ class StateHandler(object): We resolve conflicts in the following order: 1. power levels - 2. memberships - 3. other events. + 2. join rules + 3. memberships + 4. other events. """ resolved_state = {} power_key = (EventTypes.PowerLevels, "") - if power_key in conflicted_state.items(): - power_levels = conflicted_state[power_key] - resolved_state[power_key] = self._resolve_auth_events(power_levels) + if power_key in conflicted_state: + events = conflicted_state[power_key] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[power_key] = self._resolve_auth_events( + events, auth_events) auth_events.update(resolved_state) for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) resolved_state[key] = self._resolve_auth_events( events, auth_events @@ -329,6 +333,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) resolved_state[key] = self._resolve_auth_events( events, auth_events @@ -338,6 +343,7 @@ class StateHandler(object): for key, events in conflicted_state.items(): if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) resolved_state[key] = self._resolve_normal_events( events, auth_events ) diff --git a/tests/test_state.py b/tests/test_state.py index 0274c4bc18..e4e995b756 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -317,6 +317,99 @@ class StateTestCase(unittest.TestCase): {e.event_id for e in context_store["E"].current_state.values()} ) + @defer.inlineCallbacks + def test_branch_have_perms_conflict(self): + userid1 = "@user_id:example.com" + userid2 = "@user_id2:example.com" + + nodes = { + "A1": DictObj( + type=EventTypes.Create, + state_key="", + content={"creator": userid1}, + depth=1, + ), + "A2": DictObj( + type=EventTypes.Member, + state_key=userid1, + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + ), + "A3": DictObj( + type=EventTypes.Member, + state_key=userid2, + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + ), + "A4": DictObj( + type=EventTypes.PowerLevels, + state_key="", + content={ + "events": {"m.room.name": 50}, + "users": {userid1: 100, + userid2: 60}, + }, + ), + "A5": DictObj( + type=EventTypes.Name, + state_key="", + ), + "B": DictObj( + type=EventTypes.PowerLevels, + state_key="", + content={ + "events": {"m.room.name": 50}, + "users": {userid2: 30}, + }, + ), + "C": DictObj( + type=EventTypes.Name, + state_key="", + sender=userid2, + ), + "D": DictObj( + type=EventTypes.Message, + ), + } + edges = { + "A2": ["A1"], + "A3": ["A2"], + "A4": ["A3"], + "A5": ["A4"], + "B": ["A5"], + "C": ["A5"], + "D": ["B", "C"] + } + self._add_depths(nodes, edges) + graph = Graph(nodes, edges) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertSetEqual( + {"A1", "A2", "A3", "A5", "B"}, + {e.event_id for e in context_store["D"].current_state.values()} + ) + + def _add_depths(self, nodes, edges): + def _get_depth(ev): + node = nodes[ev] + if 'depth' not in node: + prevs = edges[ev] + depth = max(_get_depth(prev) for prev in prevs) + 1 + node['depth'] = depth + return node['depth'] + + for n in nodes: + _get_depth(n) + @defer.inlineCallbacks def test_annotate_with_old_message(self): event = create_event(type="test_message", name="event") -- cgit 1.5.1 From 468a2ed4ecd06b208611d3b44cd588a184efdfec Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 12 Nov 2015 16:45:28 +0000 Subject: Return non-room events from guest /events calls --- synapse/notifier.py | 20 +++++++++++++++++--- tests/rest/client/v1/test_presence.py | 3 +++ 2 files changed, 20 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/synapse/notifier.py b/synapse/notifier.py index 56c4c863b5..e3b42e2331 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -14,6 +14,8 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor, ObservableDeferred @@ -346,9 +348,9 @@ class Notifier(object): room_ids = [] if is_guest: - # TODO(daniel): Deal with non-room events too - only_room_events = True if guest_room_id: + if not self._is_world_readable(guest_room_id): + raise AuthError(403, "Guest access not allowed") room_ids = [guest_room_id] else: rooms = yield self.store.get_rooms_for_user(user.to_string()) @@ -361,6 +363,7 @@ class Notifier(object): events = [] end_token = from_token + for name, source in self.event_sources.sources.items(): keyname = "%s_key" % name before_id = getattr(before_token, keyname) @@ -377,7 +380,7 @@ class Notifier(object): room_ids=room_ids, ) - if is_guest: + if name == "room": room_member_handler = self.hs.get_handlers().room_member_handler new_events = yield room_member_handler._filter_events_for_client( user.to_string(), @@ -403,6 +406,17 @@ class Notifier(object): defer.returnValue(result) + @defer.inlineCallbacks + def _is_world_readable(self, room_id): + state = yield self.hs.get_state_handler().get_current_state( + room_id, + EventTypes.RoomHistoryVisibility + ) + if state and "history_visibility" in state.content: + defer.returnValue(state.content["history_visibility"] == "world_readable") + else: + defer.returnValue(False) + @log_function def remove_expired_streams(self): time_now_ms = self.clock.time_msec() diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 7f29d73d95..8581796f72 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -321,6 +321,9 @@ class PresenceEventStreamTestCase(unittest.TestCase): hs.handlers.room_member_handler.get_room_members = ( lambda r: self.room_members if r == "a-room" else [] ) + hs.handlers.room_member_handler._filter_events_for_client = ( + lambda user_id, events, **kwargs: events + ) self.mock_datastore = hs.get_datastore() self.mock_datastore.get_app_service_by_token = Mock(return_value=None) -- cgit 1.5.1