From 6fe04ffef2fd44a3af61fe262266b25158c46db7 Mon Sep 17 00:00:00 2001 From: Negi Fazeli Date: Mon, 23 May 2016 11:14:11 +0200 Subject: Fix set profile error with Requester. Replace flush_user with delete access token due to function removal Add a new test case for if the user is already registered --- tests/handlers/test_register.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 8b7be96bd9..9d5c653b45 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,6 +17,7 @@ from twisted.internet import defer from .. import unittest from synapse.handlers.register import RegistrationHandler +from synapse.types import UserID from tests.utils import setup_test_homeserver @@ -36,25 +37,21 @@ class RegistrationTestCase(unittest.TestCase): self.mock_distributor = Mock() self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() - hs = yield setup_test_homeserver( + self.hs = yield setup_test_homeserver( handlers=None, http_client=None, expire_access_token=True) - hs.handlers = RegistrationHandlers(hs) - self.handler = hs.get_handlers().registration_handler - hs.get_handlers().profile_handler = Mock() + self.hs.handlers = RegistrationHandlers(self.hs) + self.handler = self.hs.get_handlers().registration_handler + self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ "generate_short_term_login_token", ]) - hs.get_handlers().auth_handler = self.mock_handler + self.hs.get_handlers().auth_handler = self.mock_handler @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - """ - Returns: - The user doess not exist in this case so it will register and log it in - """ duration_ms = 200 local_part = "someone" display_name = "someone" @@ -65,3 +62,22 @@ class RegistrationTestCase(unittest.TestCase): local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') + + @defer.inlineCallbacks + def test_if_user_exists(self): + store = self.hs.get_datastore() + frank = UserID.from_string("@frank:test") + yield store.register( + user_id=frank.to_string(), + token="jkv;g498752-43gj['eamb!-5", + password_hash=None) + duration_ms = 200 + local_part = "frank" + display_name = "Frank" + user_id = "@frank:test" + mock_token = self.mock_handler.generate_short_term_login_token + mock_token.return_value = 'secret' + result_user_id, result_token = yield self.handler.get_or_create_user( + local_part, display_name, duration_ms) + self.assertEquals(result_user_id, user_id) + self.assertEquals(result_token, 'secret') -- cgit 1.5.1 From c626fc576a87f401c594cbb070d5b0000e45b4e1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 31 May 2016 13:53:48 +0100 Subject: Move the AS handler out of the Handlers object. Access it directly from the homeserver itself. It already wasn't inheriting from BaseHandler storing it on the Handlers object was already somewhat dubious. --- synapse/appservice/scheduler.py | 14 +++++++------- synapse/handlers/__init__.py | 11 ----------- synapse/handlers/appservice.py | 15 +++++---------- synapse/handlers/directory.py | 3 ++- synapse/notifier.py | 10 ++++------ synapse/server.py | 15 +++++++++++++++ tests/handlers/test_appservice.py | 6 +++--- 7 files changed, 36 insertions(+), 38 deletions(-) (limited to 'tests') diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 47a4e9f864..9afc8fd754 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -56,22 +56,22 @@ import logging logger = logging.getLogger(__name__) -class AppServiceScheduler(object): +class ApplicationServiceScheduler(object): """ Public facing API for this module. Does the required DI to tie the components together. This also serves as the "event_pool", which in this case is a simple array. """ - def __init__(self, clock, store, as_api): - self.clock = clock - self.store = store - self.as_api = as_api + def __init__(self, hs): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.as_api = hs.get_application_service_api() def create_recoverer(service, callback): - return _Recoverer(clock, store, as_api, service, callback) + return _Recoverer(self.clock, self.store, self.as_api, service, callback) self.txn_ctrl = _TransactionController( - clock, store, as_api, create_recoverer + self.clock, self.store, self.as_api, create_recoverer ) self.queuer = _ServiceQueuer(self.txn_ctrl) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 0ac5d3da3a..c0069e23d6 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice.scheduler import AppServiceScheduler -from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( RoomCreationHandler, RoomContextHandler, @@ -26,7 +24,6 @@ from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler -from .appservice import ApplicationServicesHandler from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler @@ -53,14 +50,6 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) - asapi = ApplicationServiceApi(hs) - self.appservice_handler = ApplicationServicesHandler( - hs, asapi, AppServiceScheduler( - clock=hs.get_clock(), - store=hs.get_datastore(), - as_api=asapi - ) - ) self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 75fc74c797..051ccdb380 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.appservice import ApplicationService -from synapse.types import UserID import logging @@ -35,16 +34,13 @@ def log_failure(failure): ) -# NB: Purposefully not inheriting BaseHandler since that contains way too much -# setup code which this handler does not need or use. This makes testing a lot -# easier. class ApplicationServicesHandler(object): - def __init__(self, hs, appservice_api, appservice_scheduler): + def __init__(self, hs): self.store = hs.get_datastore() - self.hs = hs - self.appservice_api = appservice_api - self.scheduler = appservice_scheduler + self.is_mine_id = hs.is_mine_id + self.appservice_api = hs.get_application_service_api() + self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False @defer.inlineCallbacks @@ -169,8 +165,7 @@ class ApplicationServicesHandler(object): @defer.inlineCallbacks def _is_unknown_user(self, user_id): - user = UserID.from_string(user_id) - if not self.hs.is_mine(user): + if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. defer.returnValue(False) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8eeb225811..4bea7f2b19 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -33,6 +33,7 @@ class DirectoryHandler(BaseHandler): super(DirectoryHandler, self).__init__(hs) self.state = hs.get_state_handler() + self.appservice_handler = hs.get_application_service_handler() self.federation = hs.get_replication_layer() self.federation.register_query_handler( @@ -281,7 +282,7 @@ class DirectoryHandler(BaseHandler): ) if not result: # Query AS to see if it exists - as_handler = self.hs.get_handlers().appservice_handler + as_handler = self.appservice_handler result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) diff --git a/synapse/notifier.py b/synapse/notifier.py index 33b79c0ec7..cbec4d30ae 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -140,8 +140,6 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 def __init__(self, hs): - self.hs = hs - self.user_to_user_stream = {} self.room_to_user_streams = {} self.appservice_to_user_streams = {} @@ -151,6 +149,8 @@ class Notifier(object): self.pending_new_room_events = [] self.clock = hs.get_clock() + self.appservice_handler = hs.get_application_service_handler() + self.state_handler = hs.get_state_handler() hs.get_distributor().observe( "user_joined_room", self._user_joined_room @@ -232,9 +232,7 @@ class Notifier(object): def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - self.hs.get_handlers().appservice_handler.notify_interested_services( - event - ) + self.appservice_handler.notify_interested_services(event) app_streams = set() @@ -449,7 +447,7 @@ class Notifier(object): @defer.inlineCallbacks def _is_world_readable(self, room_id): - state = yield self.hs.get_state_handler().get_current_state( + state = yield self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility ) diff --git a/synapse/server.py b/synapse/server.py index bfd5608b7d..7cf22b1eea 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -22,6 +22,8 @@ from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.enterprise import adbapi +from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier @@ -31,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.room import RoomListHandler +from synapse.handlers.appservice import ApplicationServicesHandler from synapse.state import StateHandler from synapse.storage import DataStore from synapse.util import Clock @@ -86,6 +89,9 @@ class HomeServer(object): 'sync_handler', 'typing_handler', 'room_list_handler', + 'application_service_api', + 'application_service_scheduler', + 'application_service_handler', 'notifier', 'distributor', 'client_resource', @@ -184,6 +190,15 @@ class HomeServer(object): def build_room_list_handler(self): return RoomListHandler(self) + def build_application_service_api(self): + return ApplicationServiceApi(self) + + def build_application_service_scheduler(self): + return ApplicationServiceScheduler(self) + + def build_application_service_handler(self): + return ApplicationServicesHandler(self) + def build_event_sources(self): return EventSources(self) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 7ddbbb9b4a..a884c95f8d 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -30,9 +30,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore = Mock(return_value=self.mock_store) - self.handler = ApplicationServicesHandler( - hs, self.mock_as_api, self.mock_scheduler - ) + hs.get_application_service_api = Mock(return_value=self.mock_as_api) + hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler) + self.handler = ApplicationServicesHandler(hs) @defer.inlineCallbacks def test_notify_interested_services(self): -- cgit 1.5.1 From 195254cae80f4748c3fc0ac3b46000047c2e6cc0 Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 1 Jun 2016 11:14:16 +0100 Subject: Inject fake room list handler in tests Otherwise it tries to start the remote public room list updating looping call which breaks. --- tests/utils.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index 59d985b5f2..006abedbc1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -67,6 +67,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): version_string="Synapse/tests", database_engine=create_engine(config.database_config), get_db_conn=db_pool.get_db_conn, + room_list_handler=object(), **kargs ) hs.setup() @@ -75,6 +76,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", database_engine=create_engine(config.database_config), + room_list_handler=object(), **kargs ) -- cgit 1.5.1 From 4a10510cd5aff790127a185ecefc83b881a717cc Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 2 Jun 2016 13:31:45 +0100 Subject: Split out the auth handler --- synapse/handlers/__init__.py | 2 -- synapse/handlers/register.py | 2 +- synapse/rest/client/v1/login.py | 11 ++++++----- synapse/rest/client/v2_alpha/account.py | 4 ++-- synapse/rest/client/v2_alpha/auth.py | 2 +- synapse/rest/client/v2_alpha/register.py | 2 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 2 +- synapse/server.py | 5 +++++ tests/rest/client/v2_alpha/test_register.py | 2 +- tests/utils.py | 15 +++++---------- 10 files changed, 23 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index c0069e23d6..d28e07f0d9 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -24,7 +24,6 @@ from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler -from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler from .search import SearchHandler @@ -50,7 +49,6 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) - self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 16f33f8371..bbc07b045e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue((user_id, token)) def auth_handler(self): - return self.hs.get_handlers().auth_handler + return self.hs.get_auth_handler() @defer.inlineCallbacks def guest_access_token_for(self, medium, address, inviter_user_id): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3b5544851b..8df9d10efa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet): self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() + self.auth_handler = self.hs.get_auth_handler() def on_GET(self, request): flows = [] @@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id, self.hs.hostname ).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id=user_id, password=login_submission["password"]) @@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( @@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( @@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if not user_exists: user_id, _ = ( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c88c270537..9a84873a5f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet): super(PasswordRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 78181b7b18..58d3cad6a1 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet): super(AuthRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1ecc02d94d..2088c316d1 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index a158c2209a..8270e8787f 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet): body = parse_json_object_from_request(request) try: old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_handlers().auth_handler + auth_handler = self.hs.get_auth_handler() (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( old_refresh_token, auth_handler.generate_refresh_token) new_access_token = yield auth_handler.issue_access_token(user_id) diff --git a/synapse/server.py b/synapse/server.py index 7cf22b1eea..dd4b81c658 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.room import RoomListHandler +from synapse.handlers.auth import AuthHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.state import StateHandler from synapse.storage import DataStore @@ -89,6 +90,7 @@ class HomeServer(object): 'sync_handler', 'typing_handler', 'room_list_handler', + 'auth_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -190,6 +192,9 @@ class HomeServer(object): def build_room_list_handler(self): return RoomListHandler(self) + def build_auth_handler(self): + return AuthHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index affd42c015..cda0a2b27c 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase): # do the dance to hook it up to the hs global self.handlers = Mock( - auth_handler=self.auth_handler, registration_handler=self.registration_handler, identity_handler=self.identity_handler, login_handler=self.login_handler @@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) + self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.config.enable_registration = True # init the thing we're testing diff --git a/tests/utils.py b/tests/utils.py index 006abedbc1..e19ae581e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): ) # bcrypt is far too slow to be doing in unit tests - def swap_out_hash_for_testing(old_build_handlers): - def build_handlers(): - handlers = old_build_handlers() - auth_handler = handlers.auth_handler - auth_handler.hash = lambda p: hashlib.md5(p).hexdigest() - auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h - return handlers - return build_handlers - - hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest() + hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h fed = kargs.get("resource_for_federation", None) if fed: -- cgit 1.5.1 From 70599ce9252997d32d0bf9f26a4e02c99bbe474d Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 2 Jun 2016 15:20:15 +0100 Subject: Allow external processes to mark a user as syncing. (#812) * Add infrastructure to the presence handler to track sync requests in external processes * Expire stale entries for dead external processes * Add an http endpoint for making users as syncing Add some docstrings and comments. * Fixes --- synapse/handlers/presence.py | 119 +++++++++++++++++++++++++++---- synapse/replication/presence_resource.py | 59 +++++++++++++++ synapse/replication/resource.py | 2 + tests/handlers/test_presence.py | 16 ++--- 4 files changed, 174 insertions(+), 22 deletions(-) create mode 100644 synapse/replication/presence_resource.py (limited to 'tests') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 37f57301fb..fc8538b41e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -68,6 +68,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000 # How often to resend presence to remote servers FEDERATION_PING_INTERVAL = 25 * 60 * 1000 +# How long we will wait before assuming that the syncs from an external process +# are dead. +EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 + assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER @@ -158,10 +162,21 @@ class PresenceHandler(object): self.serial_to_user = {} self._next_serial = 1 - # Keeps track of the number of *ongoing* syncs. While this is non zero - # a user will never go offline. + # Keeps track of the number of *ongoing* syncs on this process. While + # this is non zero a user will never go offline. self.user_to_num_current_syncs = {} + # Keeps track of the number of *ongoing* syncs on other processes. + # While any sync is ongoing on another process the user will never + # go offline. + # Each process has a unique identifier and an update frequency. If + # no update is received from that process within the update period then + # we assume that all the sync requests on that process have stopped. + # Stored as a dict from process_id to set of user_id, and a dict of + # process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs = {} + self.external_process_last_updated_ms = {} + # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. @@ -272,13 +287,26 @@ class PresenceHandler(object): # Fetch the list of users that *may* have timed out. Things may have # changed since the timeout was set, so we won't necessarily have to # take any action. - users_to_check = self.wheel_timer.fetch(now) + users_to_check = set(self.wheel_timer.fetch(now)) + + # Check whether the lists of syncing processes from an external + # process have expired. + expired_process_ids = [ + process_id for process_id, last_update + in self.external_process_last_update.items() + if now - last_update > EXTERNAL_PROCESS_EXPIRY + ] + for process_id in expired_process_ids: + users_to_check.update( + self.external_process_to_current_syncs.pop(process_id, ()) + ) + self.external_process_last_update.pop(process_id) states = [ self.user_to_current_state.get( user_id, UserPresenceState.default(user_id) ) - for user_id in set(users_to_check) + for user_id in users_to_check ] timers_fired_counter.inc_by(len(states)) @@ -286,7 +314,7 @@ class PresenceHandler(object): changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - user_to_num_current_syncs=self.user_to_num_current_syncs, + syncing_users=self.get_syncing_users(), now=now, ) @@ -363,6 +391,73 @@ class PresenceHandler(object): defer.returnValue(_user_syncing()) + def get_currently_syncing_users(self): + """Get the set of user ids that are currently syncing on this HS. + Returns: + set(str): A set of user_id strings. + """ + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + syncing_user_ids.update(self.external_process_to_current_syncs.values()) + return syncing_user_ids + + @defer.inlineCallbacks + def update_external_syncs(self, process_id, syncing_user_ids): + """Update the syncing users for an external process + + Args: + process_id(str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. + syncing_user_ids(set(str)): The set of user_ids that are + currently syncing on that server. + """ + + # Grab the previous list of user_ids that were syncing on that process + prev_syncing_user_ids = ( + self.external_process_to_current_syncs.get(process_id, set()) + ) + # Grab the current presence state for both the users that are syncing + # now and the users that were syncing before this update. + prev_states = yield self.current_state_for_users( + syncing_user_ids | prev_syncing_user_ids + ) + updates = [] + time_now_ms = self.clock.time_msec() + + # For each new user that is syncing check if we need to mark them as + # being online. + for new_user_id in syncing_user_ids - prev_syncing_user_ids: + prev_state = prev_states[new_user_id] + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=time_now_ms, + last_user_sync_ts=time_now_ms, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + # For each user that is still syncing or stopped syncing update the + # last sync time so that we will correctly apply the grace period when + # they stop syncing. + for old_user_id in prev_syncing_user_ids: + prev_state = prev_states[old_user_id] + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + yield self._update_states(updates) + + # Update the last updated time for the process. We expire the entries + # if we don't receive an update in the given timeframe. + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + self.external_process_to_current_syncs[process_id] = syncing_user_ids + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. @@ -935,15 +1030,14 @@ class PresenceEventSource(object): return self.get_new_events(user, from_key=None, include_offline=False) -def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): +def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): """Checks the presence of users that have timed out and updates as appropriate. Args: user_states(list): List of UserPresenceState's to check. is_mine_fn (fn): Function that returns if a user_id is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -954,21 +1048,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): for state in user_states: is_mine = is_mine_fn(state.user_id) - new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) + new_state = handle_timeout(state, is_mine, syncing_user_ids, now) if new_state: changes[state.user_id] = new_state return changes.values() -def handle_timeout(state, is_mine, user_to_num_current_syncs, now): +def handle_timeout(state, is_mine, syncing_user_ids, now): """Checks the presence of the user to see if any of the timers have elapsed Args: state (UserPresenceState) is_mine (bool): Whether the user is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -1002,7 +1095,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now): # If there are have been no sync for a while (and none ongoing), # set presence to offline - if not user_to_num_current_syncs.get(user_id, 0): + if user_id not in syncing_user_ids: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( state=PresenceState.OFFLINE, diff --git a/synapse/replication/presence_resource.py b/synapse/replication/presence_resource.py new file mode 100644 index 0000000000..fc18130ab4 --- /dev/null +++ b/synapse/replication/presence_resource.py @@ -0,0 +1,59 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import respond_with_json_bytes, request_handler +from synapse.http.servlet import parse_json_object_from_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + + +class PresenceResource(Resource): + """ + HTTP endpoint for marking users as syncing. + + POST /_synapse/replication/presence HTTP/1.1 + Content-Type: application/json + + { + "process_id": "", + "syncing_users": [""] + } + """ + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.clock = hs.get_clock() + self.presence_handler = hs.get_presence_handler() + + def render_POST(self, request): + self._async_render_POST(request) + return NOT_DONE_YET + + @request_handler() + @defer.inlineCallbacks + def _async_render_POST(self, request): + content = parse_json_object_from_request(request) + + process_id = content["process_id"] + syncing_user_ids = content["syncing_users"] + + yield self.presence_handler.update_external_syncs( + process_id, set(syncing_user_ids) + ) + + respond_with_json_bytes(request, 200, "{}") diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 847f212a3d..8c2d487ff4 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -16,6 +16,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.http.server import request_handler, finish_request from synapse.replication.pusher_resource import PusherResource +from synapse.replication.presence_resource import PresenceResource from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -115,6 +116,7 @@ class ReplicationResource(Resource): self.clock = hs.get_clock() self.putChild("remove_pushers", PusherResource(hs)) + self.putChild("syncing_users", PresenceResource(hs)) def render_GET(self, request): self._async_render_GET(request) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 87c795fcfa..b531ba8540 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -264,7 +264,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -282,7 +282,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -300,9 +300,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={ - user_id: 1, - }, now=now + state, is_mine=True, syncing_user_ids=set([user_id]), now=now ) self.assertIsNotNone(new_state) @@ -321,7 +319,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -340,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNone(new_state) @@ -358,7 +356,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=False, user_to_num_current_syncs={}, now=now + state, is_mine=False, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) @@ -377,7 +375,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, user_to_num_current_syncs={}, now=now + state, is_mine=True, syncing_user_ids=set(), now=now ) self.assertIsNotNone(new_state) -- cgit 1.5.1 From 56d15a05306896555ded3025f8a808bda04872fa Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 2 Jun 2016 16:28:54 +0100 Subject: Store the typing users as user_id strings. (#819) Rather than storing them as UserID objects. --- synapse/handlers/typing.py | 64 ++++++++++++++++++++++++------------------- tests/handlers/test_typing.py | 4 +-- 2 files changed, 38 insertions(+), 30 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d46f05f426..3c54307bed 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user")) +RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) class TypingHandler(object): @@ -38,7 +38,7 @@ class TypingHandler(object): self.store = hs.get_datastore() self.server_name = hs.config.server_name self.auth = hs.get_auth() - self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id self.notifier = hs.get_notifier() self.clock = hs.get_clock() @@ -67,20 +67,23 @@ class TypingHandler(object): @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): - if not self.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has started typing in %s", target_user.to_string(), room_id + "%s has started typing in %s", target_user_id, room_id ) until = self.clock.time_msec() + timeout - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) was_present = member in self._member_typing_until @@ -104,25 +107,28 @@ class TypingHandler(object): yield self._push_update( room_id=room_id, - user=target_user, + user_id=target_user_id, typing=True, ) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): - if not self.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has stopped typing in %s", target_user.to_string(), room_id + "%s has stopped typing in %s", target_user_id, room_id ) - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) if member in self._member_typing_timer: self.clock.cancel_call_later(self._member_typing_timer[member]) @@ -132,8 +138,9 @@ class TypingHandler(object): @defer.inlineCallbacks def user_left_room(self, user, room_id): - if self.is_mine(user): - member = RoomMember(room_id=room_id, user=user) + user_id = user.to_string() + if self.is_mine_id(user_id): + member = RoomMember(room_id=room_id, user=user_id) yield self._stopped_typing(member) @defer.inlineCallbacks @@ -144,7 +151,7 @@ class TypingHandler(object): yield self._push_update( room_id=member.room_id, - user=member.user, + user_id=member.user_id, typing=False, ) @@ -156,7 +163,7 @@ class TypingHandler(object): del self._member_typing_timer[member] @defer.inlineCallbacks - def _push_update(self, room_id, user, typing): + def _push_update(self, room_id, user_id, typing): domains = yield self.store.get_joined_hosts_for_room(room_id) deferreds = [] @@ -164,7 +171,7 @@ class TypingHandler(object): if domain == self.server_name: self._push_update_local( room_id=room_id, - user=user, + user_id=user_id, typing=typing ) else: @@ -173,7 +180,7 @@ class TypingHandler(object): edu_type="m.typing", content={ "room_id": room_id, - "user_id": user.to_string(), + "user_id": user_id, "typing": typing, }, )) @@ -183,23 +190,26 @@ class TypingHandler(object): @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] - user = UserID.from_string(content["user_id"]) + user_id = content["user_id"] + + # Check that the string is a valid user id + UserID.from_string(user_id) domains = yield self.store.get_joined_hosts_for_room(room_id) if self.server_name in domains: self._push_update_local( room_id=room_id, - user=user, + user_id=user_id, typing=content["typing"] ) - def _push_update_local(self, room_id, user, typing): + def _push_update_local(self, room_id, user_id, typing): room_set = self._room_typing.setdefault(room_id, set()) if typing: - room_set.add(user) + room_set.add(user_id) else: - room_set.discard(user) + room_set.discard(user_id) self._latest_room_serial += 1 self._room_serials[room_id] = self._latest_room_serial @@ -215,9 +225,7 @@ class TypingHandler(object): for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps([ - u.to_string() for u in typing - ], ensure_ascii=False) + typing_bytes = json.dumps(list(typing), ensure_ascii=False) rows.append((serial, room_id, typing_bytes)) rows.sort() return rows @@ -239,7 +247,7 @@ class TypingNotificationEventSource(object): "type": "m.typing", "room_id": room_id, "content": { - "user_ids": [u.to_string() for u in typing], + "user_ids": list(typing), }, } diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index abb739ae52..ab9899b7d5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -251,12 +251,12 @@ class TypingNotificationsTestCase(unittest.TestCase): # Gut-wrenching from synapse.handlers.typing import RoomMember - member = RoomMember(self.room_id, self.u_apple) + member = RoomMember(self.room_id, self.u_apple.to_string()) self.handler._member_typing_until[member] = 1002000 self.handler._member_typing_timer[member] = ( self.clock.call_later(1002, lambda: 0) ) - self.handler._room_typing[self.room_id] = set((self.u_apple,)) + self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),)) self.assertEquals(self.event_source.get_current_key(), 0) -- cgit 1.5.1 From 73c711243382a48b9b67fddf5ed9df2d1ee1be43 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 2 Jun 2016 11:29:44 +0100 Subject: Change CacheMetrics to be quicker We change it so that each cache has an individual CacheMetric, instead of having one global CacheMetric. This means that when a cache tries to increment a counter it does not need to go through so many indirections. --- synapse/metrics/__init__.py | 16 ++++------- synapse/metrics/metric.py | 44 +++++++++++++++--------------- synapse/util/caches/__init__.py | 20 ++++++++++---- synapse/util/caches/descriptors.py | 17 +++++++++--- synapse/util/caches/dictionary_cache.py | 8 +++--- synapse/util/caches/expiringcache.py | 8 +++--- synapse/util/caches/stream_change_cache.py | 16 +++++------ tests/metrics/test_metric.py | 23 +++++++--------- 8 files changed, 82 insertions(+), 70 deletions(-) (limited to 'tests') diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 5664d5a381..c38f24485a 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -33,11 +33,7 @@ from .metric import ( logger = logging.getLogger(__name__) -# We'll keep all the available metrics in a single toplevel dict, one shared -# for the entire process. We don't currently support per-HomeServer instances -# of metrics, because in practice any one python VM will host only one -# HomeServer anyway. This makes a lot of implementation neater -all_metrics = {} +all_metrics = [] class Metrics(object): @@ -53,7 +49,7 @@ class Metrics(object): metric = metric_class(full_name, *args, **kwargs) - all_metrics[full_name] = metric + all_metrics.append(metric) return metric def register_counter(self, *args, **kwargs): @@ -84,12 +80,12 @@ def render_all(): # TODO(paul): Internal hack update_resource_metrics() - for name in sorted(all_metrics.keys()): + for metric in all_metrics: try: - strs += all_metrics[name].render() + strs += metric.render() except Exception: - strs += ["# FAILED to render %s" % name] - logger.exception("Failed to render %s metric", name) + strs += ["# FAILED to render"] + logger.exception("Failed to render metric") strs.append("") # to generate a final CRLF diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index 368fc24984..341043952a 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -47,9 +47,6 @@ class BaseMetric(object): for k, v in zip(self.labels, values)]) ) - def render(self): - return map_concat(self.render_item, sorted(self.counts.keys())) - class CounterMetric(BaseMetric): """The simplest kind of metric; one that stores a monotonically-increasing @@ -83,6 +80,9 @@ class CounterMetric(BaseMetric): def render_item(self, k): return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] + def render(self): + return map_concat(self.render_item, sorted(self.counts.keys())) + class CallbackMetric(BaseMetric): """A metric that returns the numeric value returned by a callback whenever @@ -126,30 +126,30 @@ class DistributionMetric(object): class CacheMetric(object): - """A combination of two CounterMetrics, one to count cache hits and one to - count a total, and a callback metric to yield the current size. + __slots__ = ("name", "cache_name", "hits", "misses", "size_callback") - This metric generates standard metric name pairs, so that monitoring rules - can easily be applied to measure hit ratio.""" - - def __init__(self, name, size_callback, labels=[]): + def __init__(self, name, size_callback, cache_name): self.name = name + self.cache_name = cache_name - self.hits = CounterMetric(name + ":hits", labels=labels) - self.total = CounterMetric(name + ":total", labels=labels) + self.hits = 0 + self.misses = 0 - self.size = CallbackMetric( - name + ":size", - callback=size_callback, - labels=labels, - ) + self.size_callback = size_callback - def inc_hits(self, *values): - self.hits.inc(*values) - self.total.inc(*values) + def inc_hits(self): + self.hits += 1 - def inc_misses(self, *values): - self.total.inc(*values) + def inc_misses(self): + self.misses += 1 def render(self): - return self.hits.render() + self.total.render() + self.size.render() + size = self.size_callback() + hits = self.hits + total = self.misses + self.hits + + return [ + """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits), + """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), + """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), + ] diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index d53569ca49..ebd715c5dc 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -24,11 +24,21 @@ DEBUG_CACHES = False metrics = synapse.metrics.get_metrics_for("synapse.util.caches") caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) +# cache_counter = metrics.register_cache( +# "cache", +# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, +# labels=["name"], +# ) + + +def register_cache(name, cache): + caches_by_name[name] = cache + return metrics.register_cache( + "cache", + lambda: len(cache), + name, + ) + _string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR)) caches_by_name["string_cache"] = _string_cache diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 758f5982b0..5d25c9e762 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -22,7 +22,7 @@ from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) -from . import caches_by_name, DEBUG_CACHES, cache_counter +from . import DEBUG_CACHES, register_cache from twisted.internet import defer @@ -43,6 +43,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) class Cache(object): + __slots__ = ( + "cache", + "max_entries", + "name", + "keylen", + "sequence", + "thread", + "metrics", + ) def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): if lru: @@ -59,7 +68,7 @@ class Cache(object): self.keylen = keylen self.sequence = 0 self.thread = None - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -74,10 +83,10 @@ class Cache(object): def get(self, key, default=_CacheSentinel): val = self.cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return val - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() if default is _CacheSentinel: raise KeyError() diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index f92d80542b..b0ca1bb79d 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -15,7 +15,7 @@ from synapse.util.caches.lrucache import LruCache from collections import namedtuple -from . import caches_by_name, cache_counter +from . import register_cache import threading import logging @@ -43,7 +43,7 @@ class DictionaryCache(object): __slots__ = [] self.sentinel = Sentinel() - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -58,7 +58,7 @@ class DictionaryCache(object): def get(self, key, dict_keys=None): entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() if dict_keys is None: return DictionaryEntry(entry.full, dict(entry.value)) @@ -69,7 +69,7 @@ class DictionaryCache(object): if k in entry.value }) - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return DictionaryEntry(False, {}) def invalidate(self, key): diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 2b68c1ac93..080388958f 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache import logging @@ -49,7 +49,7 @@ class ExpiringCache(object): self._cache = {} - caches_by_name[cache_name] = self._cache + self.metrics = register_cache(cache_name, self._cache) def start(self): if not self._expiry_ms: @@ -78,9 +78,9 @@ class ExpiringCache(object): def __getitem__(self, key): try: entry = self._cache[key] - cache_counter.inc_hits(self._cache_name) + self.metrics.inc_hits() except KeyError: - cache_counter.inc_misses(self._cache_name) + self.metrics.inc_misses() raise if self._reset_expiry_on_get: diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index ea8a74ca69..3c051dabc4 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache from blist import sorteddict @@ -42,7 +42,7 @@ class StreamChangeCache(object): self._cache = sorteddict() self._earliest_known_stream_pos = current_stream_pos self.name = name - caches_by_name[self.name] = self._cache + self.metrics = register_cache(self.name, self._cache) for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) @@ -53,19 +53,19 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos < self._earliest_known_stream_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True latest_entity_change_pos = self._entity_to_key.get(entity, None) if latest_entity_change_pos is None: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False if stream_pos < latest_entity_change_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False def get_entities_changed(self, entities, stream_pos): @@ -82,10 +82,10 @@ class StreamChangeCache(object): self._cache[k] for k in keys[i:] ).intersection(entities) - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() else: result = entities - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return result diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index f3c1927ce1..f85455a5af 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -61,9 +61,6 @@ class CounterMetricTestCase(unittest.TestCase): 'vector{method="PUT"} 1', ]) - # Check that passing too few values errors - self.assertRaises(ValueError, counter.inc) - class CallbackMetricTestCase(unittest.TestCase): @@ -138,27 +135,27 @@ class CacheMetricTestCase(unittest.TestCase): def test_cache(self): d = dict() - metric = CacheMetric("cache", lambda: len(d)) + metric = CacheMetric("cache", lambda: len(d), "cache_name") self.assertEquals(metric.render(), [ - 'cache:hits 0', - 'cache:total 0', - 'cache:size 0', + 'cache:hits{name="cache_name"} 0', + 'cache:total{name="cache_name"} 0', + 'cache:size{name="cache_name"} 0', ]) metric.inc_misses() d["key"] = "value" self.assertEquals(metric.render(), [ - 'cache:hits 0', - 'cache:total 1', - 'cache:size 1', + 'cache:hits{name="cache_name"} 0', + 'cache:total{name="cache_name"} 1', + 'cache:size{name="cache_name"} 1', ]) metric.inc_hits() self.assertEquals(metric.render(), [ - 'cache:hits 1', - 'cache:total 2', - 'cache:size 1', + 'cache:hits{name="cache_name"} 1', + 'cache:total{name="cache_name"} 2', + 'cache:size{name="cache_name"} 1', ]) -- cgit 1.5.1 From 05e01f21d7012c1853ff566c8a76aa66087bfbd7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Jun 2016 17:12:48 +0100 Subject: Remove event fetching from DB threads --- synapse/replication/slave/storage/events.py | 5 - synapse/storage/appservice.py | 21 +++-- synapse/storage/events.py | 138 ---------------------------- synapse/storage/room.py | 46 ++++++---- synapse/storage/search.py | 29 +++--- synapse/storage/stream.py | 34 +++---- tests/storage/test_appservice.py | 2 +- 7 files changed, 75 insertions(+), 200 deletions(-) (limited to 'tests') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index cbc1ae4190..877c68508c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -131,15 +131,10 @@ class SlavedEventStore(BaseSlavedStore): _get_events_from_cache = DataStore._get_events_from_cache.__func__ _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ - _parse_events_txn = DataStore._parse_events_txn.__func__ - _get_events_txn = DataStore._get_events_txn.__func__ - _get_event_txn = DataStore._get_event_txn.__func__ _enqueue_events = DataStore._enqueue_events.__func__ _do_fetch = DataStore._do_fetch.__func__ - _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ _get_event_from_row = DataStore._get_event_from_row.__func__ - _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ _get_rooms_for_user_where_membership_is_txn = ( DataStore._get_rooms_for_user_where_membership_is_txn.__func__ ) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index feb9d228ae..ffb7d4a25b 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -298,6 +298,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore): dict(txn_id=txn_id, as_id=service.id) ) + @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. @@ -308,12 +309,23 @@ class ApplicationServiceTransactionStore(SQLBaseStore): A Deferred which resolves to an AppServiceTransaction or None. """ - return self.runInteraction( + entry = yield self.runInteraction( "get_oldest_unsent_appservice_txn", self._get_oldest_unsent_txn, service ) + if not entry: + defer.returnValue(None) + + event_ids = json.loads(entry["event_ids"]) + + events = yield self.get_events(event_ids) + + defer.returnValue(AppServiceTransaction( + service=service, id=entry["txn_id"], events=events + )) + def _get_oldest_unsent_txn(self, txn, service): # Monotonically increasing txn ids, so just select the smallest # one in the txns table (we delete them when they are sent) @@ -328,12 +340,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore): entry = rows[0] - event_ids = json.loads(entry["event_ids"]) - events = self._get_events_txn(txn, event_ids) - - return AppServiceTransaction( - service=service, id=entry["txn_id"], events=events - ) + return entry def _get_last_txn(self, txn, service_id): txn.execute( diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2b3f79577b..b710505a7e 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -762,41 +762,6 @@ class EventsStore(SQLBaseStore): if e_id in event_map and event_map[e_id] ]) - def _get_events_txn(self, txn, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not event_ids: - return [] - - event_map = self._get_events_from_cache( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_map] - - if not missing_events_ids: - return [ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ] - - missing_events = self._fetch_events_txn( - txn, - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - event_map.update(missing_events) - - return [ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ] - def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): for get_prev_content in (False, True): @@ -804,18 +769,6 @@ class EventsStore(SQLBaseStore): (event_id, check_redacted, get_prev_content) ) - def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False): - - events = self._get_events_txn( - txn, [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - return events[0] if events else None - def _get_events_from_cache(self, events, check_redacted, get_prev_content, allow_rejected): event_map = {} @@ -981,34 +934,6 @@ class EventsStore(SQLBaseStore): return rows - def _fetch_events_txn(self, txn, events, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not events: - return {} - - rows = self._fetch_event_rows( - txn, events, - ) - - if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] - - res = [ - self._get_event_from_row_txn( - txn, - row["internal_metadata"], row["json"], row["redacts"], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - rejected_reason=row["rejects"], - ) - for row in rows - ] - - return { - r.event_id: r - for r in res - } - @defer.inlineCallbacks def _get_event_from_row(self, internal_metadata, js, redacted, check_redacted=True, get_prev_content=False, @@ -1070,69 +995,6 @@ class EventsStore(SQLBaseStore): defer.returnValue(ev) - def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, - check_redacted=True, get_prev_content=False, - rejected_reason=None): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if rejected_reason: - rejected_reason = self._simple_select_one_onecol_txn( - txn, - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - ) - - ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - if check_redacted and redacted: - ev = prune_event(ev) - - redaction_id = self._simple_select_one_onecol_txn( - txn, - table="redactions", - keyvalues={"redacts": ev.event_id}, - retcol="event_id", - ) - - ev.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = self._get_event_txn( - txn, - redaction_id, - check_redacted=False - ) - - if because: - ev.unsigned["redacted_because"] = because - - if get_prev_content and "replaces_state" in ev.unsigned: - prev = self._get_event_txn( - txn, - ev.unsigned["replaces_state"], - get_prev_content=False, - ) - if prev: - ev.unsigned["prev_content"] = prev.content - ev.unsigned["prev_sender"] = prev.sender - - self._get_event_cache.prefill( - (ev.event_id, check_redacted, get_prev_content), ev - ) - - return ev - - def _parse_events_txn(self, txn, rows): - event_ids = [r["event_id"] for r in rows] - - return self._get_events_txn(txn, event_ids) - @defer.inlineCallbacks def count_daily_messages(self): """ diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 26933e593a..97f9f1929c 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -194,32 +194,44 @@ class RoomStore(SQLBaseStore): @cachedInlineCallbacks() def get_room_name_and_aliases(self, room_id): - def f(txn): + def get_room_name(txn): sql = ( - "SELECT event_id FROM current_state_events " - "WHERE room_id = ? " + "SELECT name FROM room_names" + " INNER JOIN current_state_events USING (room_id, event_id)" + " WHERE room_id = ?" + " LIMIT 1" ) - sql += " AND ((type = 'm.room.name' AND state_key = '')" - sql += " OR type = 'm.room.aliases')" - txn.execute(sql, (room_id,)) - results = self.cursor_to_dict(txn) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + return None - return self._parse_events_txn(txn, results) + return [row[0] for row in txn.fetchall()] - events = yield self.runInteraction("get_room_name_and_aliases", f) + def get_room_aliases(txn): + sql = ( + "SELECT content FROM current_state_events" + " INNER JOIN events USING (room_id, event_id)" + " WHERE room_id = ?" + ) + txn.execute(sql, (room_id,)) + return [row[0] for row in txn.fetchall()] + + name = yield self.runInteraction("get_room_name", get_room_name) + alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases) - name = None aliases = [] - for e in events: - if e.type == 'm.room.name': - if 'name' in e.content: - name = e.content['name'] - elif e.type == 'm.room.aliases': - if 'aliases' in e.content: - aliases.extend(e.content['aliases']) + for c in alias_contents: + try: + content = json.loads(c) + except: + continue + + aliases.extend(content.get('aliases', [])) defer.returnValue((name, aliases)) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 0224299625..12941d1775 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -21,6 +21,7 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine import logging import re +import ujson as json logger = logging.getLogger(__name__) @@ -52,7 +53,7 @@ class SearchStore(BackgroundUpdateStore): def reindex_search_txn(txn): sql = ( - "SELECT stream_ordering, event_id FROM events" + "SELECT stream_ordering, event_id, room_id, type, content FROM events" " WHERE ? <= stream_ordering AND stream_ordering < ?" " AND (%s)" " ORDER BY stream_ordering DESC" @@ -61,28 +62,30 @@ class SearchStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = txn.fetchall() + rows = self.cursor_to_dict(txn) if not rows: return 0 - min_stream_id = rows[-1][0] - event_ids = [row[1] for row in rows] - - events = self._get_events_txn(txn, event_ids) + min_stream_id = rows[-1]["stream_ordering"] event_search_rows = [] - for event in events: + for row in rows: try: - event_id = event.event_id - room_id = event.room_id - content = event.content - if event.type == "m.room.message": + event_id = row["event_id"] + room_id = row["room_id"] + etype = row["type"] + try: + content = json.loads(row["content"]) + except: + continue + + if etype == "m.room.message": key = "content.body" value = content["body"] - elif event.type == "m.room.topic": + elif etype == "m.room.topic": key = "content.topic" value = content["topic"] - elif event.type == "m.room.name": + elif etype == "m.room.name": key = "content.name" value = content["name"] except (KeyError, AttributeError): diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 95b12559a6..b9ad965fd6 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -132,29 +132,25 @@ class StreamStore(SQLBaseStore): return True return False - ret = self._get_events_txn( - txn, - # apply the filter on the room id list - [ - r["event_id"] for r in rows - if app_service_interested(r) - ], - get_prev_content=True - ) + return [r for r in rows if app_service_interested(r)] - self._set_before_and_after(ret, rows) + rows = yield self.runInteraction("get_appservice_room_stream", f) - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key + ret = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) - return ret, key + self._set_before_and_after(ret, rows, topo_order=from_id is None) - results = yield self.runInteraction("get_appservice_room_stream", f) - defer.returnValue(results) + if rows: + key = "s%d" % max(r["stream_ordering"] for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = to_key + + defer.returnValue((ret, key)) @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 5734198121..f44c4870e3 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -357,7 +357,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store._get_events_txn = Mock(return_value=events) + self.store.get_events = Mock(return_value=events) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1 From 310197bab5cf8ed2c26fae522f15f092dbcdff58 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Jun 2016 09:34:50 +0100 Subject: Fix AS retries --- synapse/storage/appservice.py | 4 ++-- tests/storage/test_appservice.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index ffb7d4a25b..a281571637 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -320,10 +320,10 @@ class ApplicationServiceTransactionStore(SQLBaseStore): event_ids = json.loads(entry["event_ids"]) - events = yield self.get_events(event_ids) + event_map = yield self.get_events(event_ids) defer.returnValue(AppServiceTransaction( - service=service, id=entry["txn_id"], events=events + service=service, id=entry["txn_id"], events=event_map.values() )) def _get_oldest_unsent_txn(self, txn, service): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index f44c4870e3..6db4b966db 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -353,21 +353,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_oldest_unsent_txn(self): service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] + events = {"e1": Mock(event_id="e1"), "e2": Mock(event_id="e2")} other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out self.store.get_events = Mock(return_value=events) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) - yield self._insert_txn(service.id, 10, events) + yield self._insert_txn(service.id, 10, events.values()) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) txn = yield self.store.get_oldest_unsent_txn(service) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) - self.assertEquals(events, txn.events) + self.assertEquals(events.values(), txn.events) @defer.inlineCallbacks def test_get_appservices_by_state_single(self): -- cgit 1.5.1 From 84379062f9ec259abc302af321d4ed8f5a958c01 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Jun 2016 10:24:50 +0100 Subject: Fix AS retries, but with correct ordering --- synapse/storage/appservice.py | 4 ++-- tests/storage/test_appservice.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index a281571637..d1ee533fac 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -320,10 +320,10 @@ class ApplicationServiceTransactionStore(SQLBaseStore): event_ids = json.loads(entry["event_ids"]) - event_map = yield self.get_events(event_ids) + events = yield self._get_events(event_ids) defer.returnValue(AppServiceTransaction( - service=service, id=entry["txn_id"], events=event_map.values() + service=service, id=entry["txn_id"], events=events )) def _get_oldest_unsent_txn(self, txn, service): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 6db4b966db..3e2862daae 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -353,21 +353,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_oldest_unsent_txn(self): service = Mock(id=self.as_list[0]["id"]) - events = {"e1": Mock(event_id="e1"), "e2": Mock(event_id="e2")} + events = [Mock(event_id="e1"), Mock(event_id="e2")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events = Mock(return_value=events) + self.store._get_events = Mock(return_value=events) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) - yield self._insert_txn(service.id, 10, events.values()) + yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) txn = yield self.store.get_oldest_unsent_txn(service) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) - self.assertEquals(events.values(), txn.events) + self.assertEquals(events, txn.events) @defer.inlineCallbacks def test_get_appservices_by_state_single(self): -- cgit 1.5.1 From 6e7dc7c7dde377794c23d5db6f25ffacfb08e82a Mon Sep 17 00:00:00 2001 From: Negar Fazeli Date: Wed, 8 Jun 2016 19:16:46 +0200 Subject: Fix a bug caused by a change in auth_handler function Fix the relevant unit test cases --- synapse/handlers/register.py | 4 ++-- tests/handlers/test_register.py | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index bbc07b045e..e0aaefe7be 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -388,8 +388,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - auth_handler = self.hs.get_handlers().auth_handler - token = auth_handler.generate_short_term_login_token(user_id, duration_seconds) + token = self.auth_handler().generate_short_term_login_token( + user_id, duration_seconds) if need_register: yield self.store.register( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 9d5c653b45..69a5e5b1d4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -41,14 +41,15 @@ class RegistrationTestCase(unittest.TestCase): handlers=None, http_client=None, expire_access_token=True) + self.auth_handler = Mock( + generate_short_term_login_token=Mock(return_value='secret')) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ "generate_short_term_login_token", ]) - - self.hs.get_handlers().auth_handler = self.mock_handler + self.hs.get_auth_handler = Mock(return_value=self.auth_handler) @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): @@ -56,8 +57,6 @@ class RegistrationTestCase(unittest.TestCase): local_part = "someone" display_name = "someone" user_id = "@someone:test" - mock_token = self.mock_handler.generate_short_term_login_token - mock_token.return_value = 'secret' result_user_id, result_token = yield self.handler.get_or_create_user( local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) @@ -75,8 +74,6 @@ class RegistrationTestCase(unittest.TestCase): local_part = "frank" display_name = "Frank" user_id = "@frank:test" - mock_token = self.mock_handler.generate_short_term_login_token - mock_token.return_value = 'secret' result_user_id, result_token = yield self.handler.get_or_create_user( local_part, display_name, duration_ms) self.assertEquals(result_user_id, user_id) -- cgit 1.5.1 From 7dbb473339bc41daf6c05b64756f97e011f653f5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 9 Jun 2016 18:50:38 +0100 Subject: Add function to load config without generating it Renames ``load_config`` to ``load_or_generate_config`` Adds a method called ``load_config`` that just loads the config. The main synapse.app.homeserver will continue to use ``load_or_generate_config`` to retain backwards compat. However new worker processes can use ``load_config`` to load the config avoiding some of the cruft needed to generate the config. As the new ``load_config`` method is expected to be used by new configs it removes support for the legacy commandline overrides that ``load_or_generate_config`` supports --- synapse/app/homeserver.py | 3 +- synapse/config/_base.py | 147 ++++++++++++++++++++++++++++++------------ tests/config/test_generate.py | 2 +- tests/config/test_load.py | 22 ++++++- 4 files changed, 126 insertions(+), 48 deletions(-) (limited to 'tests') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 22e1721fc4..40ffd9bf0d 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -266,10 +266,9 @@ def setup(config_options): HomeServer """ try: - config = HomeServerConfig.load_config( + config = HomeServerConfig.load_or_generate_config( "Synapse Homeserver", config_options, - generate_section="Homeserver" ) except ConfigError as e: sys.stderr.write("\n" + e.message + "\n") diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 7449f36491..af9f17bf7b 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -157,9 +157,40 @@ class Config(object): return default_config, config @classmethod - def load_config(cls, description, argv, generate_section=None): + def load_config(cls, description, argv): + config_parser = argparse.ArgumentParser( + description=description, + ) + config_parser.add_argument( + "-c", "--config-path", + action="append", + metavar="CONFIG_FILE", + help="Specify config file. Can be given multiple times and" + " may specify directories containing *.yaml files." + ) + + config_parser.add_argument( + "--keys-directory", + metavar="DIRECTORY", + help="Where files such as certs and signing keys are stored when" + " their location is given explicitly in the config." + " Defaults to the directory containing the last config file", + ) + + config_args = config_parser.parse_args(argv) + + config_files = find_config_files(search_paths=config_args.config_path) + obj = cls() + obj.read_config_files( + config_files, + keys_directory=config_args.keys_directory, + generate_keys=False, + ) + return obj + @classmethod + def load_or_generate_config(cls, description, argv): config_parser = argparse.ArgumentParser(add_help=False) config_parser.add_argument( "-c", "--config-path", @@ -176,7 +207,7 @@ class Config(object): config_parser.add_argument( "--report-stats", action="store", - help="Stuff", + help="Whether the generated config reports anonymized usage statistics", choices=["yes", "no"] ) config_parser.add_argument( @@ -197,36 +228,11 @@ class Config(object): ) config_args, remaining_args = config_parser.parse_known_args(argv) + config_files = find_config_files(search_paths=config_args.config_path) + generate_keys = config_args.generate_keys - config_files = [] - if config_args.config_path: - for config_path in config_args.config_path: - if os.path.isdir(config_path): - # We accept specifying directories as config paths, we search - # inside that directory for all files matching *.yaml, and then - # we apply them in *sorted* order. - files = [] - for entry in os.listdir(config_path): - entry_path = os.path.join(config_path, entry) - if not os.path.isfile(entry_path): - print ( - "Found subdirectory in config directory: %r. IGNORING." - ) % (entry_path, ) - continue - - if not entry.endswith(".yaml"): - print ( - "Found file in config directory that does not" - " end in '.yaml': %r. IGNORING." - ) % (entry_path, ) - continue - - files.append(entry_path) - - config_files.extend(sorted(files)) - else: - config_files.append(config_path) + obj = cls() if config_args.generate_config: if config_args.report_stats is None: @@ -299,28 +305,43 @@ class Config(object): " -c CONFIG-FILE\"" ) - if config_args.keys_directory: - config_dir_path = config_args.keys_directory - else: - config_dir_path = os.path.dirname(config_args.config_path[-1]) - config_dir_path = os.path.abspath(config_dir_path) + obj.read_config_files( + config_files, + keys_directory=config_args.keys_directory, + generate_keys=generate_keys, + ) + + if generate_keys: + return None + + obj.invoke_all("read_arguments", args) + + return obj + + def read_config_files(self, config_files, keys_directory=None, + generate_keys=False): + if not keys_directory: + keys_directory = os.path.dirname(config_files[-1]) + + config_dir_path = os.path.abspath(keys_directory) specified_config = {} for config_file in config_files: - yaml_config = cls.read_config_file(config_file) + yaml_config = self.read_config_file(config_file) specified_config.update(yaml_config) if "server_name" not in specified_config: raise ConfigError(MISSING_SERVER_NAME) server_name = specified_config["server_name"] - _, config = obj.generate_config( + _, config = self.generate_config( config_dir_path=config_dir_path, server_name=server_name, is_generating_file=False, ) config.pop("log_config") config.update(specified_config) + if "report_stats" not in config: raise ConfigError( MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + @@ -328,11 +349,51 @@ class Config(object): ) if generate_keys: - obj.invoke_all("generate_files", config) + self.invoke_all("generate_files", config) return - obj.invoke_all("read_config", config) - - obj.invoke_all("read_arguments", args) - - return obj + self.invoke_all("read_config", config) + + +def find_config_files(search_paths): + """Finds config files using a list of search paths. If a path is a file + then that file path is added to the list. If a search path is a directory + then all the "*.yaml" files in that directory are added to the list in + sorted order. + + Args: + search_paths(list(str)): A list of paths to search. + + Returns: + list(str): A list of file paths. + """ + + config_files = [] + if search_paths: + for config_path in search_paths: + if os.path.isdir(config_path): + # We accept specifying directories as config paths, we search + # inside that directory for all files matching *.yaml, and then + # we apply them in *sorted* order. + files = [] + for entry in os.listdir(config_path): + entry_path = os.path.join(config_path, entry) + if not os.path.isfile(entry_path): + print ( + "Found subdirectory in config directory: %r. IGNORING." + ) % (entry_path, ) + continue + + if not entry.endswith(".yaml"): + print ( + "Found file in config directory that does not" + " end in '.yaml': %r. IGNORING." + ) % (entry_path, ) + continue + + files.append(entry_path) + + config_files.extend(sorted(files)) + else: + config_files.append(config_path) + return config_files diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 4329d73974..8f57fbeb23 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -30,7 +30,7 @@ class ConfigGenerationTestCase(unittest.TestCase): shutil.rmtree(self.dir) def test_generate_config_generates_files(self): - HomeServerConfig.load_config("", [ + HomeServerConfig.load_or_generate_config("", [ "--generate-config", "-c", self.file, "--report-stats=yes", diff --git a/tests/config/test_load.py b/tests/config/test_load.py index bf46233c5c..161a87d7e3 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -34,6 +34,8 @@ class ConfigLoadingTestCase(unittest.TestCase): self.generate_config_and_remove_lines_containing("server_name") with self.assertRaises(Exception): HomeServerConfig.load_config("", ["-c", self.file]) + with self.assertRaises(Exception): + HomeServerConfig.load_or_generate_config("", ["-c", self.file]) def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() @@ -54,11 +56,24 @@ class ConfigLoadingTestCase(unittest.TestCase): "was: %r" % (config.macaroon_secret_key,) ) + config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + self.assertTrue( + hasattr(config, "macaroon_secret_key"), + "Want config to have attr macaroon_secret_key" + ) + if len(config.macaroon_secret_key) < 5: + self.fail( + "Want macaroon secret key to be string of at least length 5," + "was: %r" % (config.macaroon_secret_key,) + ) + def test_load_succeeds_if_macaroon_secret_key_missing(self): self.generate_config_and_remove_lines_containing("macaroon") config1 = HomeServerConfig.load_config("", ["-c", self.file]) config2 = HomeServerConfig.load_config("", ["-c", self.file]) + config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) + self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key) def test_disable_registration(self): self.generate_config() @@ -70,14 +85,17 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertFalse(config.enable_registration) + config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + self.assertFalse(config.enable_registration) + # Check that either config value is clobbered by the command line. - config = HomeServerConfig.load_config("", [ + config = HomeServerConfig.load_or_generate_config("", [ "-c", self.file, "--enable-registration" ]) self.assertTrue(config.enable_registration) def generate_config(self): - HomeServerConfig.load_config("", [ + HomeServerConfig.load_or_generate_config("", [ "--generate-config", "-c", self.file, "--report-stats=yes", -- cgit 1.5.1 From 0113ad36ee7bc315aa162c42277b90764825f219 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 17 Jun 2016 15:13:13 +0100 Subject: Enable use_frozen_events in tests --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index e19ae581e0..6e41ae1ff6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -54,6 +54,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.trusted_third_party_id_servers = [] config.room_invite_state_types = [] + config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} if "clock" not in kargs: -- cgit 1.5.1 From 0a32208e5dde4980a5962f17e9b27f2e28e1f3f1 Mon Sep 17 00:00:00 2001 From: Martin Weinelt Date: Mon, 6 Jun 2016 02:05:57 +0200 Subject: Rework ldap integration with ldap3 Use the pure-python ldap3 library, which eliminates the need for a system dependency. Offer both a `search` and `simple_bind` mode, for more sophisticated ldap scenarios. - `search` tries to find a matching DN within the `user_base` while employing the `user_filter`, then tries the bind when a single matching DN was found. - `simple_bind` tries the bind against a specific DN by combining the localpart and `user_base` Offer support for STARTTLS on a plain connection. The configuration was changed to reflect these new possibilities. Signed-off-by: Martin Weinelt --- synapse/config/ldap.py | 102 +++++++++++++++------ synapse/handlers/auth.py | 203 ++++++++++++++++++++++++++++++++++------- synapse/python_dependencies.py | 3 + tests/utils.py | 1 + 4 files changed, 249 insertions(+), 60 deletions(-) (limited to 'tests') diff --git a/synapse/config/ldap.py b/synapse/config/ldap.py index 9c14593a99..d83c2230be 100644 --- a/synapse/config/ldap.py +++ b/synapse/config/ldap.py @@ -13,40 +13,88 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ConfigError + + +MISSING_LDAP3 = ( + "Missing ldap3 library. This is required for LDAP Authentication." +) + + +class LDAPMode(object): + SIMPLE = "simple", + SEARCH = "search", + + LIST = (SIMPLE, SEARCH) class LDAPConfig(Config): def read_config(self, config): - ldap_config = config.get("ldap_config", None) - if ldap_config: - self.ldap_enabled = ldap_config.get("enabled", False) - self.ldap_server = ldap_config["server"] - self.ldap_port = ldap_config["port"] - self.ldap_tls = ldap_config.get("tls", False) - self.ldap_search_base = ldap_config["search_base"] - self.ldap_search_property = ldap_config["search_property"] - self.ldap_email_property = ldap_config["email_property"] - self.ldap_full_name_property = ldap_config["full_name_property"] - else: - self.ldap_enabled = False - self.ldap_server = None - self.ldap_port = None - self.ldap_tls = False - self.ldap_search_base = None - self.ldap_search_property = None - self.ldap_email_property = None - self.ldap_full_name_property = None + ldap_config = config.get("ldap_config", {}) + + self.ldap_enabled = ldap_config.get("enabled", False) + + if self.ldap_enabled: + # verify dependencies are available + try: + import ldap3 + ldap3 # to stop unused lint + except ImportError: + raise ConfigError(MISSING_LDAP3) + + self.ldap_mode = LDAPMode.SIMPLE + + # verify config sanity + self.require_keys(ldap_config, [ + "uri", + "base", + "attributes", + ]) + + self.ldap_uri = ldap_config["uri"] + self.ldap_start_tls = ldap_config.get("start_tls", False) + self.ldap_base = ldap_config["base"] + self.ldap_attributes = ldap_config["attributes"] + + if "bind_dn" in ldap_config: + self.ldap_mode = LDAPMode.SEARCH + self.require_keys(ldap_config, [ + "bind_dn", + "bind_password", + ]) + + self.ldap_bind_dn = ldap_config["bind_dn"] + self.ldap_bind_password = ldap_config["bind_password"] + self.ldap_filter = ldap_config.get("filter", None) + + # verify attribute lookup + self.require_keys(ldap_config['attributes'], [ + "uid", + "name", + "mail", + ]) + + def require_keys(self, config, required): + missing = [key for key in required if key not in config] + if missing: + raise ConfigError( + "LDAP enabled but missing required config values: {}".format( + ", ".join(missing) + ) + ) def default_config(self, **kwargs): return """\ # ldap_config: # enabled: true - # server: "ldap://localhost" - # port: 389 - # tls: false - # search_base: "ou=Users,dc=example,dc=com" - # search_property: "cn" - # email_property: "email" - # full_name_property: "givenName" + # uri: "ldap://ldap.example.com:389" + # start_tls: true + # base: "ou=users,dc=example,dc=com" + # attributes: + # uid: "cn" + # mail: "email" + # name: "givenName" + # #bind_dn: + # #bind_password: + # #filter: "(objectClass=posixAccount)" """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b38f81e999..968095c141 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -20,6 +20,7 @@ from synapse.api.constants import LoginType from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor +from synapse.config.ldap import LDAPMode from twisted.web.client import PartialDownloadError @@ -28,6 +29,12 @@ import bcrypt import pymacaroons import simplejson +try: + import ldap3 +except ImportError: + ldap3 = None + pass + import synapse.util.stringutils as stringutils @@ -50,17 +57,20 @@ class AuthHandler(BaseHandler): self.INVALID_TOKEN_HTTP_STATUS = 401 self.ldap_enabled = hs.config.ldap_enabled - self.ldap_server = hs.config.ldap_server - self.ldap_port = hs.config.ldap_port - self.ldap_tls = hs.config.ldap_tls - self.ldap_search_base = hs.config.ldap_search_base - self.ldap_search_property = hs.config.ldap_search_property - self.ldap_email_property = hs.config.ldap_email_property - self.ldap_full_name_property = hs.config.ldap_full_name_property - - if self.ldap_enabled is True: - import ldap - logger.info("Import ldap version: %s", ldap.__version__) + if self.ldap_enabled: + if not ldap3: + raise RuntimeError( + 'Missing ldap3 library. This is required for LDAP Authentication.' + ) + self.ldap_mode = hs.config.ldap_mode + self.ldap_uri = hs.config.ldap_uri + self.ldap_start_tls = hs.config.ldap_start_tls + self.ldap_base = hs.config.ldap_base + self.ldap_filter = hs.config.ldap_filter + self.ldap_attributes = hs.config.ldap_attributes + if self.ldap_mode == LDAPMode.SEARCH: + self.ldap_bind_dn = hs.config.ldap_bind_dn + self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? @@ -452,40 +462,167 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): - if not self.ldap_enabled: - logger.debug("LDAP not configured") + """ Attempt to authenticate a user against an LDAP Server + and register an account if none exists. + + Returns: + True if authentication against LDAP was successful + """ + + if not ldap3 or not self.ldap_enabled: defer.returnValue(False) - import ldap + if self.ldap_mode not in LDAPMode.LIST: + raise RuntimeError( + 'Invalid ldap mode specified: {mode}'.format( + mode=self.ldap_mode + ) + ) - logger.info("Authenticating %s with LDAP" % user_id) try: - ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) - logger.debug("Connecting LDAP server at %s" % ldap_url) - l = ldap.initialize(ldap_url) - if self.ldap_tls: - logger.debug("Initiating TLS") - self._connection.start_tls_s() + server = ldap3.Server(self.ldap_uri) + logger.debug( + "Attempting ldap connection with %s", + self.ldap_uri + ) - local_name = UserID.from_string(user_id).localpart + localpart = UserID.from_string(user_id).localpart + if self.ldap_mode == LDAPMode.SIMPLE: + # bind with the the local users ldap credentials + bind_dn = "{prop}={value},{base}".format( + prop=self.ldap_attributes['uid'], + value=localpart, + base=self.ldap_base + ) + conn = ldap3.Connection(server, bind_dn, password) + logger.debug( + "Established ldap connection in simple mode: %s", + conn + ) - dn = "%s=%s, %s" % ( - self.ldap_search_property, - local_name, - self.ldap_search_base) - logger.debug("DN for LDAP authentication: %s" % dn) + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in simple mode through StartTLS: %s", + conn + ) + + conn.bind() + + elif self.ldap_mode == LDAPMode.SEARCH: + # connect with preconfigured credentials and search for local user + conn = ldap3.Connection( + server, + self.ldap_bind_dn, + self.ldap_bind_password + ) + logger.debug( + "Established ldap connection in search mode: %s", + conn + ) + + if self.ldap_start_tls: + conn.start_tls() + logger.debug( + "Upgraded ldap connection in search mode through StartTLS: %s", + conn + ) - l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) + conn.bind() + # find matching dn + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + if self.ldap_filter: + query = "(&{query}{filter})".format( + query=query, + filter=self.ldap_filter + ) + logger.debug("ldap search filter: %s", query) + result = conn.search(self.ldap_base, query) + + if result and len(conn.response) == 1: + # found exactly one result + user_dn = conn.response[0]['dn'] + logger.debug('ldap search found dn: %s', user_dn) + + # unbind and reconnect, rebind with found dn + conn.unbind() + conn = ldap3.Connection( + server, + user_dn, + password, + auto_bind=True + ) + else: + # found 0 or > 1 results, abort! + logger.warn( + "ldap search returned unexpected (%d!=1) amount of results", + len(conn.response) + ) + defer.returnValue(False) + + logger.info( + "User authenticated against ldap server: %s", + conn + ) + + # check for existing account, if none exists, create one if not (yield self.does_user_exist(user_id)): - handler = self.hs.get_handlers().registration_handler - user_id, access_token = ( - yield handler.register(localpart=local_name) + # query user metadata for account creation + query = "({prop}={value})".format( + prop=self.ldap_attributes['uid'], + value=localpart + ) + + if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: + query = "(&{filter}{user_filter})".format( + filter=query, + user_filter=self.ldap_filter + ) + logger.debug("ldap registration filter: %s", query) + + result = conn.search( + search_base=self.ldap_base, + search_filter=query, + attributes=[ + self.ldap_attributes['name'], + self.ldap_attributes['mail'] + ] ) + if len(conn.response) == 1: + attrs = conn.response[0]['attributes'] + mail = attrs[self.ldap_attributes['mail']][0] + name = attrs[self.ldap_attributes['name']][0] + + # create account + registration_handler = self.hs.get_handlers().registration_handler + user_id, access_token = ( + yield registration_handler.register(localpart=localpart) + ) + + # TODO: bind email, set displayname with data from ldap directory + + logger.info( + "ldap registration successful: %d: %s (%s, %)", + user_id, + localpart, + name, + mail + ) + else: + logger.warn( + "ldap registration failed: unexpected (%d!=1) amount of results", + len(result) + ) + defer.returnValue(False) + defer.returnValue(True) - except ldap.LDAPError, e: - logger.warn("LDAP error: %s", e) + except ldap3.core.exceptions.LDAPException as e: + logger.warn("Error during ldap authentication: %s", e) defer.returnValue(False) @defer.inlineCallbacks diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index e0a7a19777..e024cec0a2 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -48,6 +48,9 @@ CONDITIONAL_REQUIREMENTS = { "Jinja2>=2.8": ["Jinja2>=2.8"], "bleach>=1.4.2": ["bleach>=1.4.2"], }, + "ldap": { + "ldap3>=1.0": ["ldap3>=1.0"], + }, } diff --git a/tests/utils.py b/tests/utils.py index 6e41ae1ff6..ed547bc39b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} + config.ldap_enabled = False if "clock" not in kargs: kargs["clock"] = MockClock() -- cgit 1.5.1 From 2455ad8468ea3d372d0f3b3828efa10419ad68ad Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 24 Jun 2016 13:34:20 +0100 Subject: Remove room name & alias test as get_room_name_and_alias is now gone --- tests/replication/slave/storage/test_events.py | 41 -------------------------- 1 file changed, 41 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 17587fda00..f33e6f60fb 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -58,47 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def tearDown(self): [unpatch() for unpatch in self.unpatches] - @defer.inlineCallbacks - def test_room_name_and_aliases(self): - create = yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") - yield self.persist(type="m.room.name", key="", name="name1") - yield self.persist( - type="m.room.aliases", key="blue", aliases=["#1:blue"] - ) - yield self.replicate() - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) - ) - - # Set the room name. - yield self.persist(type="m.room.name", key="", name="name2") - yield self.replicate() - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) - ) - - # Set the room aliases. - yield self.persist( - type="m.room.aliases", key="blue", aliases=["#2:blue"] - ) - yield self.replicate() - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) - ) - - # Leave and join the room clobbering the state. - yield self.persist(type="m.room.member", key=USER_ID, membership="leave") - yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - reset_state=[create] - ) - yield self.replicate() - - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), (None, []) - ) - @defer.inlineCallbacks def test_room_members(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.5.1 From 7335f0addae9ff473403eaaffd7d2b02a9f1991f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 14:44:25 +0100 Subject: Add ReadWriteLock --- synapse/util/async.py | 82 +++++++++++++++++++++++++++++++++++++++++++++ tests/util/test_rwlock.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 tests/util/test_rwlock.py (limited to 'tests') diff --git a/synapse/util/async.py b/synapse/util/async.py index 40be7fe7e3..c84b23ff46 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -194,3 +194,85 @@ class Linearizer(object): self.key_to_defer.pop(key, None) defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. + + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield curr_writer + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield defer.gatherResults(to_wait_on) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py new file mode 100644 index 0000000000..1d745ae1a7 --- /dev/null +++ b/tests/util/test_rwlock.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from synapse.util.async import ReadWriteLock + + +class ReadWriteLockTestCase(unittest.TestCase): + + def _assert_called_before_not_after(self, lst, first_false): + for i, d in enumerate(lst[:first_false]): + self.assertTrue(d.called, msg="%d was unexpectedly false" % i) + + for i, d in enumerate(lst[first_false:]): + self.assertFalse( + d.called, msg="%d was unexpectedly true" % (i + first_false) + ) + + def test_rwlock(self): + rwlock = ReadWriteLock() + + key = object() + + ds = [ + rwlock.read(key), # 0 + rwlock.read(key), # 1 + rwlock.write(key), # 2 + rwlock.write(key), # 3 + rwlock.read(key), # 4 + rwlock.read(key), # 5 + rwlock.write(key), # 6 + ] + + self._assert_called_before_not_after(ds, 2) + + with ds[0].result: + self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(ds, 2) + + with ds[1].result: + self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(ds, 3) + + with ds[2].result: + self._assert_called_before_not_after(ds, 3) + self._assert_called_before_not_after(ds, 4) + + with ds[3].result: + self._assert_called_before_not_after(ds, 4) + self._assert_called_before_not_after(ds, 6) + + with ds[5].result: + self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(ds, 6) + + with ds[4].result: + self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(ds, 7) + + with ds[6].result: + pass + + d = rwlock.write(key) + self.assertTrue(d.called) + with d.result: + pass + + d = rwlock.read(key) + self.assertTrue(d.called) + with d.result: + pass -- cgit 1.5.1 From 0136a522b18a734db69171d60566f501c0ced663 Mon Sep 17 00:00:00 2001 From: Negar Fazeli Date: Fri, 8 Jul 2016 16:53:18 +0200 Subject: Bug fix: expire invalid access tokens --- synapse/api/auth.py | 3 +++ synapse/handlers/auth.py | 5 +++-- synapse/handlers/register.py | 6 +++--- synapse/rest/client/v1/register.py | 2 +- tests/api/test_auth.py | 31 ++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 4 ++-- 6 files changed, 42 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a4d658a9d0..521a52e001 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -629,7 +629,10 @@ class Auth(object): except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. + if self.hs.config.expire_access_token: + raise ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e259213a36..5a0ed9d6b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -637,12 +637,13 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token(self, user_id, extra_caveats=None, + duration_in_ms=(60 * 60 * 1000)): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8c3381df8a..6b33b27149 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds, + def get_or_create_user(self, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_short_term_login_token( - user_id, duration_seconds) + token = self.auth_handler().generate_access_token( + user_id, None, duration_in_ms) if need_register: yield self.store.register( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ce7099b18f..8e1f1b7845 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet): user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds, + duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ad269af0ec..960c23d631 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) - macaroon.add_first_party_caveat("time < 1") # ms + macaroon.add_first_party_caveat("time < -2000") # ms self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True @@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase): yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_with_valid_duration(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("time < 900000000") # ms + + self.hs.clock.now = 5000 # seconds + self.hs.config.expire_access_token = True + + user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 69a5e5b1d4..a7de3c7c17 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase): http_client=None, expire_access_token=True) self.auth_handler = Mock( - generate_short_term_login_token=Mock(return_value='secret')) + generate_access_token=Mock(return_value='secret')) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ - "generate_short_term_login_token", + "generate_access_token", ]) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) -- cgit 1.5.1 From a98d2152049b0a61426ed3d8b6ac872a9ca3f535 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jul 2016 15:59:25 +0100 Subject: Add filter param to /messages API --- synapse/handlers/message.py | 16 ++++++++++++---- synapse/rest/client/v1/room.py | 11 ++++++++++- tests/storage/event_injector.py | 1 + tests/storage/test_events.py | 12 ++++++------ 4 files changed, 29 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad2753c1b5..dc76d34a52 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -66,7 +66,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, event_filter=None): """Get messages in a room. Args: @@ -75,11 +75,11 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None Returns: dict: Pagination API results """ user_id = requester.user.to_string() - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: room_token = pagin_config.from_token.room_key @@ -129,8 +129,13 @@ class MessageHandler(BaseHandler): room_id, max_topo ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, ) next_token = pagin_config.from_token.copy_and_replace( @@ -144,6 +149,9 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) + if event_filter: + events = event_filter.filter(events) + events = yield filter_events_for_client( self.store, user_id, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86fbe2747d..866a1e9120 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event from synapse.http.servlet import parse_json_object_from_request import logging import urllib +import ujson as json logger = logging.getLogger(__name__) @@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args + filter_bytes = request.args.get("filter", None) + if filter_bytes: + filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + event_filter=event_filter, ) defer.returnValue((200, msgs)) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py index f22ba8db89..38556da9a7 100644 --- a/tests/storage/event_injector.py +++ b/tests/storage/event_injector.py @@ -30,6 +30,7 @@ class EventInjector: def create_room(self, room): builder = self.event_builder_factory.new({ "type": EventTypes.Create, + "sender": "", "room_id": room.to_string(), "content": {}, }) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 18a6cff0c7..3762b38e37 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_count_daily_messages(self): - self.db_pool.runQuery("DELETE FROM stats_reporting") + yield self.db_pool.runQuery("DELETE FROM stats_reporting") self.hs.clock.now = 100 @@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase): # it isn't old enough. count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(1, self.hs.clock.now) + yield self._assert_stats_reporting(1, self.hs.clock.now) # Already reported yesterday, two new events from today. yield self.event_injector.inject_message(room, user, "Yeah they are!") @@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += 60 * 60 * 24 count = yield self.store.count_daily_messages() self.assertEqual(2, count) # 2 since yesterday - self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever # Last reported too recently. yield self.event_injector.inject_message(room, user, "Who could disagree?") self.hs.clock.now += 60 * 60 * 22 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(4, self.hs.clock.now) + yield self._assert_stats_reporting(4, self.hs.clock.now) # Last reported too long ago yield self.event_injector.inject_message(room, user, "No one.") self.hs.clock.now += 60 * 60 * 26 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(5, self.hs.clock.now) + yield self._assert_stats_reporting(5, self.hs.clock.now) # And now let's actually report something yield self.event_injector.inject_message(room, user, "Indeed.") @@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += (60 * 60 * 24) + 50 count = yield self.store.count_daily_messages() self.assertEqual(3, count) - self._assert_stats_reporting(8, self.hs.clock.now) + yield self._assert_stats_reporting(8, self.hs.clock.now) @defer.inlineCallbacks def _get_last_stream_token(self): -- cgit 1.5.1 From f863a52ceacf69ab19b073383be80603a2f51c0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 13:19:07 +0100 Subject: Add device_id support to /login Add a 'devices' table to the storage, as well as a 'device_id' column to refresh_tokens. Allow the client to pass a device_id, and initial_device_display_name, to /login. If login is successful, then register the device in the devices table if it wasn't known already. If no device_id was supplied, make one up. Associate the device_id with the access token and refresh token, so that we can get at it again later. Ensure that the device_id is copied from the refresh token to the access_token when the token is refreshed. --- synapse/handlers/auth.py | 19 +++--- synapse/handlers/device.py | 71 ++++++++++++++++++++ synapse/rest/client/v1/login.py | 39 ++++++++++- synapse/rest/client/v2_alpha/tokenrefresh.py | 10 ++- synapse/server.py | 5 ++ synapse/storage/__init__.py | 3 + synapse/storage/devices.py | 77 ++++++++++++++++++++++ synapse/storage/registration.py | 28 +++++--- synapse/storage/schema/delta/33/devices.sql | 21 ++++++ .../schema/delta/33/refreshtoken_device.sql | 16 +++++ tests/handlers/test_device.py | 75 +++++++++++++++++++++ tests/storage/test_registration.py | 21 ++++-- 12 files changed, 354 insertions(+), 31 deletions(-) create mode 100644 synapse/handlers/device.py create mode 100644 synapse/storage/devices.py create mode 100644 synapse/storage/schema/delta/33/devices.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device.sql create mode 100644 tests/handlers/test_device.py (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 983994fa95..ce9bc18849 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -361,7 +361,7 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id, device_id=None): """ Gets login tuple for the user with the given user ID. @@ -372,6 +372,7 @@ class AuthHandler(BaseHandler): Args: user_id (str): canonical User ID + device_id (str): the device ID to associate with the access token Returns: A tuple of: The access token for the user's session. @@ -380,9 +381,9 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + refresh_token = yield self.issue_refresh_token(user_id, device_id) defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks @@ -638,15 +639,17 @@ class AuthHandler(BaseHandler): defer.returnValue(False) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) @defer.inlineCallbacks - def issue_refresh_token(self, user_id): + def issue_refresh_token(self, user_id, device_id=None): refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) + yield self.store.add_refresh_token_to_user(user_id, refresh_token, + device_id) defer.returnValue(refresh_token) def generate_access_token(self, user_id, extra_caveats=None, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..8d7d9874f8 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import StoreError +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string_with_symbols(16) + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except StoreError: + attempts += 1 + + raise StoreError(500, "Couldn't generate a device ID.") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a1f2ba8773..e8b791519c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id=user_id, password=login_submission["password"], ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(registered_user_id) + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id + ) ) result = { "user_id": registered_user_id, @@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 8270e8787f..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet): try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_auth_handler() - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/server.py b/synapse/server.py index d49a1a8a96..e8b166990d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,6 +25,7 @@ from twisted.enterprise import adbapi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication +from synapse.handlers.device import DeviceHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -92,6 +93,7 @@ class HomeServer(object): 'typing_handler', 'room_list_handler', 'auth_handler', + 'device_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -197,6 +199,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_device_handler(self): + return DeviceHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1c93e18f9d..73fb334dd6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from twisted.internet import defer + +from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) @@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore, OpenIdStore, ClientIpStore, + DeviceStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py new file mode 100644 index 0000000000..9065e96d28 --- /dev/null +++ b/synapse/storage/devices.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class DeviceStore(SQLBaseStore): + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name, + ignore_if_known=True): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device + ignore_if_known (bool): ignore integrity errors which mean the + device is already known + Returns: + defer.Deferred + Raises: + StoreError: if ignore_if_known is False and the device was already + known + """ + try: + yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=ignore_if_known, + ) + except Exception as e: + logger.error("store_device with device_id=%s failed: %s", + device_id, e) + raise StoreError(500, "Problem storing device.") + + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a namedtuple containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d957a629dc..26ef1cfd8a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore): self.clock = hs.get_clock() @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): + def add_refresh_token_to_user(self, user_id, token, device_id=None): """Adds a refresh token for the given user. Args: user_id (str): The user ID. token (str): The new refresh token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_refresh_token_to_user", ) @@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore): ) def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. + """Exchange a refresh token for a new one. Doing so invalidates the old refresh token - refresh tokens are single use. Args: - token (str): The refresh token of a user. + refresh_token (str): The refresh token of a user. token_generator (fn: str -> str): Function which, when given a user ID, returns a unique refresh token for that user. This function must never return the same value twice. Returns: - tuple of (user_id, refresh_token) + tuple of (user_id, new_refresh_token, device_id) Raises: StoreError if no user was found with that refresh token. """ @@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore): ) def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" txn.execute(sql, (old_token,)) rows = self.cursor_to_dict(txn) if not rows: raise StoreError(403, "Did not recognize refresh token") user_id = rows[0]["user_id"] + device_id = rows[0]["device_id"] # TODO(danielwh): Maybe perform a validation on the macaroon that # macaroon.user_id == user_id. @@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore): sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" txn.execute(sql, (new_token, old_token,)) - return user_id, new_token + return user_id, new_token, device_id @defer.inlineCallbacks def is_server_admin(self, user): @@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql new file mode 100644 index 0000000000..b21da00dde --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE refresh_tokens ADD COLUMN device_id BIGINT; diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py new file mode 100644 index 0000000000..cc6512ccc7 --- /dev/null +++ b/tests/handlers/test_device.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.handlers.device import DeviceHandler +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DeviceHandlers(object): + def __init__(self, hs): + self.device_handler = DeviceHandler(hs) + + +class DeviceTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = handlers = DeviceHandlers(self.hs) + self.handler = handlers.device_handler + + @defer.inlineCallbacks + def test_device_is_created_if_doesnt_exist(self): + res = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_is_preserved_if_exists(self): + res1 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res1, "fco") + + res2 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="new display name" + ) + self.assertEqual(res2, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_id_is_made_up_if_unspecified(self): + device_id = yield self.handler.check_device_registered( + user_id="theresa", + device_id=None, + initial_device_display_name="display" + ) + + dev = yield self.handler.store.get_device("theresa", device_id) + self.assertEqual(dev["display_name"], "display") diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b8384c98d8..b03ca303a2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "BcDeFgHiJkLmNoPqRsTuVwXyZa" ] self.pwhash = "{xx1}123456789" + self.device_id = "akgjhdjklgshg" @defer.inlineCallbacks def test_register(self): @@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { "name": self.user_id, + "device_id": self.device_id, }, result ) @@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_exchange_refresh_token_valid(self): uid = stringutils.random_string(32) + device_id = stringutils.random_string(16) generator = TokenGenerator() last_token = generator.generate(uid) self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, last_token,)) + "INSERT INTO refresh_tokens(user_id, token, device_id) " + "VALUES(?,?,?)", + (uid, last_token, device_id)) - (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( - last_token, generator.generate) + (found_user_id, refresh_token, device_id) = \ + yield self.store.exchange_refresh_token(last_token, + generator.generate) self.assertEqual(uid, found_user_id) rows = yield self.db_pool.runQuery( - "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) - self.assertEqual([(refresh_token,)], rows) + "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", + (uid, )) + self.assertEqual([(refresh_token, device_id)], rows) # We issued token 1, then exchanged it for token 2 expected_refresh_token = u"%s-%d" % (uid, 2,) self.assertEqual(expected_refresh_token, refresh_token) -- cgit 1.5.1 From 0da0d0a29d807c481152b1580acbbe36f24cf771 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 13:12:22 +0100 Subject: rest/client/v2_alpha/register.py: Refactor flow somewhat. This is meant to be an *almost* non-functional change, with the exception that it fixes what looks a lot like a bug in that it only calls `auth_handler.add_threepid` and `add_pusher` once instead of three times. The idea is to move the generation of the `access_token` out of `registration_handler.register`, because `access_token`s now require a device_id, and we only want to generate a device_id once registration has been successful. --- synapse/rest/client/v2_alpha/register.py | 177 ++++++++++++++++------------ tests/rest/client/v2_alpha/test_register.py | 3 +- 2 files changed, 104 insertions(+), 76 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e8d34b06b0..707bde0f34 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -199,92 +199,55 @@ class RegisterRestServlet(RestServlet): "Already registered user ID %r for this session", registered_user_id ) - access_token = yield self.auth_handler.issue_access_token(registered_user_id) - refresh_token = yield self.auth_handler.issue_refresh_token( - registered_user_id + # don't re-register the email address + add_email = False + else: + # NB: This may be from the auth handler and NOT from the POST + if 'password' not in params: + raise SynapseError(400, "Missing password.", + Codes.MISSING_PARAM) + + desired_username = params.get("username", None) + new_password = params.get("password", None) + guest_access_token = params.get("guest_access_token", None) + + (registered_user_id, _) = yield self.registration_handler.register( + localpart=desired_username, + password=new_password, + guest_access_token=guest_access_token, + generate_token=False, ) - defer.returnValue((200, { - "user_id": registered_user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - })) - # NB: This may be from the auth handler and NOT from the POST - if 'password' not in params: - raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) + # remember that we've now registered that user account, and with + # what user ID (since the user may not have specified) + self.auth_handler.set_session_data( + session_id, "registered_user_id", registered_user_id + ) - desired_username = params.get("username", None) - new_password = params.get("password", None) - guest_access_token = params.get("guest_access_token", None) + add_email = True - (user_id, token) = yield self.registration_handler.register( - localpart=desired_username, - password=new_password, - guest_access_token=guest_access_token, + access_token = yield self.auth_handler.issue_access_token( + registered_user_id ) - # remember that we've now registered that user account, and with what - # user ID (since the user may not have specified) - self.auth_handler.set_session_data( - session_id, "registered_user_id", user_id - ) - - if result and LoginType.EMAIL_IDENTITY in result: + if add_email and result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] - - for reqd in ['medium', 'address', 'validated_at']: - if reqd not in threepid: - logger.info("Can't add incomplete 3pid") - else: - yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], - ) - - # And we add an email pusher for them by default, but only - # if email notifications are enabled (so people don't start - # getting mail spam where they weren't before if email - # notifs are set up on a home server) - if ( - self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - ): - # Pull the ID of the access token back out of the db - # It would really make more sense for this to be passed - # up when the access token is saved, but that's quite an - # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) - - yield self.hs.get_pusherpool().add_pusher( - user_id=user_id, - access_token=user_tuple["token_id"], - kind="email", - app_id="m.email", - app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], - lang=None, # We don't know a user's language here - data={}, - ) - - if 'bind_email' in params and params['bind_email']: + reqd = ('medium', 'address', 'validated_at') + if all(x in threepid for x in reqd): + yield self._register_email_threepid( + registered_user_id, threepid, access_token + ) + # XXX why is bind_email not protected by this? + else: + logger.info("Can't add incomplete 3pid") + if params.get("bind_email"): logger.info("bind_email specified: binding") - - emailThreepid = result[LoginType.EMAIL_IDENTITY] - threepid_creds = emailThreepid['threepid_creds'] - logger.debug("Binding emails %s to %s" % ( - emailThreepid, user_id - )) - yield self.identity_handler.bind_threepid(threepid_creds, user_id) + yield self._bind_email(registered_user_id, threepid) else: logger.info("bind_email not specified: not binding email") - result = yield self._create_registration_details(user_id, token) + result = yield self._create_registration_details(registered_user_id, + access_token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -324,6 +287,70 @@ class RegisterRestServlet(RestServlet): ) defer.returnValue((yield self._create_registration_details(user_id, token))) + @defer.inlineCallbacks + def _register_email_threepid(self, user_id, threepid, token): + """Add an email address as a 3pid identifier + + Also adds an email pusher for the email address, if configured in the + HS config + + Args: + user_id (str): id of user + threepid (object): m.login.email.identity auth response + token (str): access_token for the user + Returns: + defer.Deferred: + """ + yield self.auth_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) + + # And we add an email pusher for them by default, but only + # if email notifications are enabled (so people don't start + # getting mail spam where they weren't before if email + # notifs are set up on a home server) + if (self.hs.config.email_enable_notifs and + self.hs.config.email_notif_for_new_users): + # Pull the ID of the access token back out of the db + # It would really make more sense for this to be passed + # up when the access token is saved, but that's quite an + # invasive change I'd rather do separately. + user_tuple = yield self.store.get_user_by_access_token( + token + ) + token_id = user_tuple["token_id"] + + yield self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) + defer.returnValue() + + def _bind_email(self, user_id, email_threepid): + """Bind emails to the given user_id on the identity server + + Args: + user_id (str): user id to bind the emails to + email_threepid (object): m.login.email.identity auth response + Returns: + defer.Deferred: + """ + threepid_creds = email_threepid['threepid_creds'] + logger.debug("Binding emails %s to %s" % ( + email_threepid, user_id + )) + return self.identity_handler.bind_threepid(threepid_creds, user_id) + @defer.inlineCallbacks def _create_registration_details(self, user_id, token): refresh_token = yield self.auth_handler.issue_refresh_token(user_id) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index cda0a2b27c..9a4215fef7 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -114,7 +114,8 @@ class RegisterRestServletTestCase(unittest.TestCase): "username": "kermit", "password": "monkey" }, None) - self.registration_handler.register = Mock(return_value=(user_id, token)) + self.registration_handler.register = Mock(return_value=(user_id, None)) + self.auth_handler.issue_access_token = Mock(return_value=token) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) -- cgit 1.5.1 From 40cbffb2d2ca0166f1377ac4ec5988046ea4ca10 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 18:46:19 +0100 Subject: Further registration refactoring * `RegistrationHandler.appservice_register` no longer issues an access token: instead it is left for the caller to do it. (There are two of these, one in `synapse/rest/client/v1/register.py`, which now simply calls `AuthHandler.issue_access_token`, and the other in `synapse/rest/client/v2_alpha/register.py`, which is covered below). * In `synapse/rest/client/v2_alpha/register.py`, move the generation of access_tokens into `_create_registration_details`. This means that the normal flow no longer needs to call `AuthHandler.issue_access_token`; the shared-secret flow can tell `RegistrationHandler.register` not to generate a token; and the appservice flow continues to work despite the above change. --- synapse/handlers/register.py | 13 +++++--- synapse/rest/client/v1/register.py | 4 ++- synapse/rest/client/v2_alpha/register.py | 50 +++++++++++++++++++++-------- synapse/storage/registration.py | 6 ++-- tests/rest/client/v2_alpha/test_register.py | 6 +++- 5 files changed, 57 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6b33b27149..94b19d0cb0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -99,8 +99,13 @@ class RegistrationHandler(BaseHandler): localpart : The local part of the user ID to register. If None, one will be generated. password (str) : The password to assign to this user so they can - login again. This can be None which means they cannot login again - via a password (e.g. the user is an application service user). + login again. This can be None which means they cannot login again + via a password (e.g. the user is an application service user). + generate_token (bool): Whether a new access token should be + generated. Having this be True should be considered deprecated, + since it offers no means of associating a device_id with the + access_token. Instead you should call auth_handler.issue_access_token + after registration. Returns: A tuple of (user_id, access_token). Raises: @@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, - token=token, password_hash="", appservice_id=service_id, create_profile_with_localpart=user.localpart, ) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8e1f1b7845..28b59952c3 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -60,6 +60,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth_handler = hs.get_auth_handler() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -299,9 +300,10 @@ class RegisterRestServlet(ClientV1RestServlet): user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler - (user_id, token) = yield handler.appservice_register( + user_id = yield handler.appservice_register( user_localpart, as_token ) + token = yield self.auth_handler.issue_access_token(user_id) self._remove_session(session) defer.returnValue({ "user_id": user_id, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5db953a1e3..04004cfbbd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -226,19 +226,17 @@ class RegisterRestServlet(RestServlet): add_email = True - access_token = yield self.auth_handler.issue_access_token( + result = yield self._create_registration_details( registered_user_id ) if add_email and result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] yield self._register_email_threepid( - registered_user_id, threepid, access_token, + registered_user_id, threepid, result["access_token"], params.get("bind_email") ) - result = yield self._create_registration_details(registered_user_id, - access_token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -246,10 +244,10 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token): - (user_id, token) = yield self.registration_handler.appservice_register( + user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + defer.returnValue((yield self._create_registration_details(user_id))) @defer.inlineCallbacks def _do_shared_secret_registration(self, username, password, mac): @@ -273,10 +271,12 @@ class RegisterRestServlet(RestServlet): 403, "HMAC incorrect", ) - (user_id, token) = yield self.registration_handler.register( - localpart=username, password=password + (user_id, _) = yield self.registration_handler.register( + localpart=username, password=password, generate_token=False, ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + + result = yield self._create_registration_details(user_id) + defer.returnValue(result) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token, bind_email): @@ -349,11 +349,31 @@ class RegisterRestServlet(RestServlet): defer.returnValue() @defer.inlineCallbacks - def _create_registration_details(self, user_id, token): - refresh_token = yield self.auth_handler.issue_refresh_token(user_id) + def _create_registration_details(self, user_id): + """Complete registration of newly-registered user + + Issues access_token and refresh_token, and builds the success response + body. + + Args: + (str) user_id: full canonical @user:id + + + Returns: + defer.Deferred: (object) dictionary for response from /register + """ + + access_token = yield self.auth_handler.issue_access_token( + user_id + ) + + refresh_token = yield self.auth_handler.issue_refresh_token( + user_id + ) + defer.returnValue({ "user_id": user_id, - "access_token": token, + "access_token": access_token, "home_server": self.hs.hostname, "refresh_token": refresh_token, }) @@ -366,7 +386,11 @@ class RegisterRestServlet(RestServlet): generate_token=False, make_guest=True ) - access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + access_token = self.auth_handler.generate_access_token( + user_id, ["guest = true"] + ) + # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok + # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26ef1cfd8a..9a92b35361 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks - def register(self, user_id, token, password_hash, + def register(self, user_id, token=None, password_hash=None, was_guest=False, make_guest=False, appservice_id=None, create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. + token (str): The desired access token to use for this user. If this + is not None, the given access token is associated with the user + id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9a4215fef7..ccbb8776d3 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase): "id": "1234" } self.registration_handler.appservice_register = Mock( - return_value=(user_id, token) + return_value=user_id ) + self.auth_handler.issue_access_token = Mock(return_value=token) + (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { @@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase): } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) + self.auth_handler.issue_access_token.assert_called_once_with( + user_id) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.5.1 From b97a1356b149f62e5b2c28b09818d74b445cc635 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 18:38:26 +0100 Subject: Register a device_id in the /v2/register flow. This doesn't cover *all* of the registration flows, but it does cover the most common ones: in particular: shared_secret registration, appservice registration, and normal user/pass registration. Pull device_id from the registration parameters. Register the device in the devices table. Associate the device with the returned access and refresh tokens. Profit. --- synapse/rest/client/v2_alpha/register.py | 54 +++++++++++++++++++++-------- tests/rest/client/v2_alpha/test_register.py | 13 +++++-- 2 files changed, 49 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b7e03ea9d1..d401722224 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -93,6 +93,7 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -145,7 +146,7 @@ class RegisterRestServlet(RestServlet): if isinstance(desired_username, basestring): result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0] + desired_username, request.args["access_token"][0], body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -155,7 +156,7 @@ class RegisterRestServlet(RestServlet): # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( - desired_username, desired_password, body["mac"] + desired_username, desired_password, body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -236,7 +237,7 @@ class RegisterRestServlet(RestServlet): add_email = True result = yield self._create_registration_details( - registered_user_id + registered_user_id, body ) if add_email and result and LoginType.EMAIL_IDENTITY in result: @@ -252,14 +253,14 @@ class RegisterRestServlet(RestServlet): return 200, {} @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token): + def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id))) + defer.returnValue((yield self._create_registration_details(user_id, body))) @defer.inlineCallbacks - def _do_shared_secret_registration(self, username, password, mac): + def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -267,7 +268,7 @@ class RegisterRestServlet(RestServlet): # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface - got_mac = str(mac) + got_mac = str(body["mac"]) want_mac = hmac.new( key=self.hs.config.registration_shared_secret, @@ -284,7 +285,7 @@ class RegisterRestServlet(RestServlet): localpart=username, password=password, generate_token=False, ) - result = yield self._create_registration_details(user_id) + result = yield self._create_registration_details(user_id, body) defer.returnValue(result) @defer.inlineCallbacks @@ -358,35 +359,58 @@ class RegisterRestServlet(RestServlet): defer.returnValue() @defer.inlineCallbacks - def _create_registration_details(self, user_id): + def _create_registration_details(self, user_id, body): """Complete registration of newly-registered user - Issues access_token and refresh_token, and builds the success response - body. + Allocates device_id if one was not given; also creates access_token + and refresh_token. Args: (str) user_id: full canonical @user:id - + (object) body: dictionary supplied to /register call, from + which we pull device_id and initial_device_name Returns: defer.Deferred: (object) dictionary for response from /register """ + device_id = yield self._register_device(user_id, body) access_token = yield self.auth_handler.issue_access_token( - user_id + user_id, device_id=device_id ) refresh_token = yield self.auth_handler.issue_refresh_token( - user_id + user_id, device_id=device_id ) - defer.returnValue({ "user_id": user_id, "access_token": access_token, "home_server": self.hs.hostname, "refresh_token": refresh_token, + "device_id": device_id, }) + def _register_device(self, user_id, body): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) body: dictionary supplied to /register call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + # register the user's device + device_id = body.get("device_id") + initial_display_name = body.get("initial_device_display_name") + device_id = self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + return device_id + @defer.inlineCallbacks def _do_guest_registration(self): if not self.hs.config.allow_guest_access: diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index ccbb8776d3..3bd7065e32 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -30,6 +30,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler = Mock() self.identity_handler = Mock() self.login_handler = Mock() + self.device_handler = Mock() # do the dance to hook it up to the hs global self.handlers = Mock( @@ -42,6 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) + self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.config.enable_registration = True # init the thing we're testing @@ -107,9 +109,11 @@ class RegisterRestServletTestCase(unittest.TestCase): def test_POST_user_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" + device_id = "frogfone" self.request_data = json.dumps({ "username": "kermit", - "password": "monkey" + "password": "monkey", + "device_id": device_id, }) self.registration_handler.check_username = Mock(return_value=True) self.auth_result = (True, None, { @@ -118,18 +122,21 @@ class RegisterRestServletTestCase(unittest.TestCase): }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) self.auth_handler.issue_access_token = Mock(return_value=token) + self.device_handler.check_device_registered = \ + Mock(return_value=device_id) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, + "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) self.auth_handler.issue_access_token.assert_called_once_with( - user_id) + user_id, device_id=device_id) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.5.1 From ec041b335ecb20008609c8603338ab8c586615be Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 15:25:40 +0100 Subject: Record device_id in client_ips Record the device_id when we add a client ip; it's somewhat redundant as we could get it via the access_token, but it will make querying rather easier. --- synapse/api/auth.py | 29 +++++++++++++++++++++++------ synapse/storage/client_ips.py | 3 ++- tests/api/test_auth.py | 10 +++++++++- 3 files changed, 34 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ff7d816cfc..eca8513905 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -586,6 +586,10 @@ class Auth(object): token_id = user_info["token_id"] is_guest = user_info["is_guest"] + # device_id may not be present if get_user_by_access_token has been + # stubbed out. + device_id = user_info.get("device_id") + ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( "User-Agent", @@ -597,7 +601,8 @@ class Auth(object): user=user, access_token=access_token, ip=ip_addr, - user_agent=user_agent + user_agent=user_agent, + device_id=device_id, ) if is_guest and not allow_guest: @@ -695,6 +700,7 @@ class Auth(object): "user": user, "is_guest": True, "token_id": None, + "device_id": None, } elif rights == "delete_pusher": # We don't store these tokens in the database @@ -702,13 +708,20 @@ class Auth(object): "user": user, "is_guest": False, "token_id": None, + "device_id": None, } else: - # This codepath exists so that we can actually return a - # token ID, because we use token IDs in place of device - # identifiers throughout the codebase. - # TODO(daniel): Remove this fallback when device IDs are - # properly implemented. + # This codepath exists for several reasons: + # * so that we can actually return a token ID, which is used + # in some parts of the schema (where we probably ought to + # use device IDs instead) + # * the only way we currently have to invalidate an + # access_token is by removing it from the database, so we + # have to check here that it is still in the db + # * some attributes (notably device_id) aren't stored in the + # macaroon. They probably should be. + # TODO: build the dictionary from the macaroon once the + # above are fixed ret = yield self._look_up_user_by_access_token(macaroon_str) if ret["user"] != user: logger.error( @@ -782,10 +795,14 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", errcode=Codes.UNKNOWN_TOKEN ) + # we use ret.get() below because *lots* of unit tests stub out + # get_user_by_access_token in a way where it only returns a couple of + # the fields. user_info = { "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "device_id": ret.get("device_id"), } defer.returnValue(user_info) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index a90990e006..74330a8ddf 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -35,7 +35,7 @@ class ClientIpStore(SQLBaseStore): super(ClientIpStore, self).__init__(hs) @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user.to_string(), access_token, ip) @@ -59,6 +59,7 @@ class ClientIpStore(SQLBaseStore): "access_token": access_token, "ip": ip, "user_agent": user_agent, + "device_id": device_id, }, values={ "last_seen": now, diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 960c23d631..e91723ca3d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -45,6 +45,7 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", + "device_id": "device", } self.store.get_user_by_access_token = Mock(return_value=user_info) @@ -143,7 +144,10 @@ class AuthTestCase(unittest.TestCase): # TODO(danielwh): Remove this mock when we remove the # get_user_by_access_token fallback. self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org"} + return_value={ + "name": "@baldrick:matrix.org", + "device_id": "device", + } ) user_id = "@baldrick:matrix.org" @@ -158,6 +162,10 @@ class AuthTestCase(unittest.TestCase): user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + # TODO: device_id should come from the macaroon, but currently comes + # from the db. + self.assertEqual(user_info["device_id"], "device") + @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): user_id = "@baldrick:matrix.org" -- cgit 1.5.1 From bc8f265f0a8443e918b17a94f4b2fa319e70a21f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 16:34:00 +0100 Subject: GET /devices endpoint implement a GET /devices endpoint which lists all of the user's devices. It also returns the last IP where we saw that device, so there is some dancing to fish that out of the user_ips table. --- synapse/handlers/device.py | 27 ++++++++ synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/_base.py | 13 ++-- synapse/rest/client/v2_alpha/devices.py | 51 ++++++++++++++ synapse/storage/client_ips.py | 72 ++++++++++++++++++++ synapse/storage/devices.py | 22 +++++- synapse/storage/schema/delta/33/user_ips_index.sql | 16 +++++ tests/handlers/test_device.py | 78 ++++++++++++++++++---- tests/storage/test_client_ips.py | 62 +++++++++++++++++ tests/storage/test_devices.py | 71 ++++++++++++++++++++ 10 files changed, 397 insertions(+), 17 deletions(-) create mode 100644 synapse/rest/client/v2_alpha/devices.py create mode 100644 synapse/storage/schema/delta/33/user_ips_index.sql create mode 100644 tests/storage/test_client_ips.py create mode 100644 tests/storage/test_devices.py (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 8d7d9874f8..6bbbf59e52 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -69,3 +69,30 @@ class DeviceHandler(BaseHandler): attempts += 1 raise StoreError(500, "Couldn't generate a device ID.") + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """ + Retrieve the given user's devices + + Args: + user_id (str): + Returns: + defer.Deferred: dict[str, dict[str, X]]: map from device_id to + info on the device + """ + + devices = yield self.store.get_devices_by_user(user_id) + + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id) for device_id in devices.keys()) + ) + + for device_id in devices.keys(): + ip = ips.get((user_id, device_id), {}) + devices[device_id].update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) + + defer.returnValue(devices) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 8b223e032b..14227f1cdb 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import ( account_data, report_event, openid, + devices, ) from synapse.http.server import JsonResource @@ -90,3 +91,4 @@ class ClientRestResource(JsonResource): account_data.register_servlets(hs, client_resource) report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) + devices.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b6faa2b0e6..20e765f48f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -25,7 +25,9 @@ import logging logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,)): +def client_v2_patterns(path_regex, releases=(0,), + v2_alpha=True, + unstable=True): """Creates a regex compiled client path with the correct client path prefix. @@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)): Returns: SRE_Pattern """ - patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] - unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") - patterns.append(re.compile("^" + unstable_prefix + path_regex)) + patterns = [] + if v2_alpha: + patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) + if unstable: + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) for release in releases: new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py new file mode 100644 index 0000000000..5cf8bd1afa --- /dev/null +++ b/synapse/rest/client/v2_alpha/devices.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet + +from ._base import client_v2_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DevicesRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + devices = yield self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + defer.returnValue((200, {"devices": devices})) + + +def register_servlets(hs, http_server): + DevicesRestServlet(hs).register(http_server) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index a90990e006..07161496ca 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from ._base import SQLBaseStore, Cache from twisted.internet import defer +logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller # times give more inserts into the database even for readonly API hits @@ -66,3 +69,72 @@ class ClientIpStore(SQLBaseStore): desc="insert_client_ip", lock=False, ) + + @defer.inlineCallbacks + def get_last_client_ip_by_device(self, devices): + """For each device_id listed, give the user_ip it was last seen on + + Args: + devices (iterable[(str, str)]): list of (user_id, device_id) pairs + + Returns: + defer.Deferred: resolves to a dict, where the keys + are (user_id, device_id) tuples. The values are also dicts, with + keys giving the column names + """ + + res = yield self.runInteraction( + "get_last_client_ip_by_device", + self._get_last_client_ip_by_device_txn, + retcols=( + "user_id", + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ), + devices=devices + ) + + ret = {(d["user_id"], d["device_id"]): d for d in res} + defer.returnValue(ret) + + @classmethod + def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): + def where_clause_for_device(d): + return + + where_clauses = [] + bindings = [] + for (user_id, device_id) in devices: + if device_id is None: + where_clauses.append("(user_id = ? AND device_id IS NULL)") + bindings.extend((user_id, )) + else: + where_clauses.append("(user_id = ? AND device_id = ?)") + bindings.extend((user_id, device_id)) + + inner_select = ( + "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " + "WHERE %(where)s " + "GROUP BY user_id, device_id" + ) % { + "where": " OR ".join(where_clauses), + } + + sql = ( + "SELECT %(retcols)s FROM user_ips " + "JOIN (%(inner_select)s) ips ON" + " user_ips.last_seen = ips.mls AND" + " user_ips.user_id = ips.user_id AND" + " (user_ips.device_id = ips.device_id OR" + " (user_ips.device_id IS NULL AND ips.device_id IS NULL)" + " )" + ) % { + "retcols": ",".join("user_ips." + c for c in retcols), + "inner_select": inner_select, + } + + txn.execute(sql, bindings) + return cls.cursor_to_dict(txn) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9065e96d28..1cc6e07f2b 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -65,7 +65,7 @@ class DeviceStore(SQLBaseStore): user_id (str): The ID of the user which owns the device device_id (str): The ID of the device to retrieve Returns: - defer.Deferred for a namedtuple containing the device information + defer.Deferred for a dict containing the device information Raises: StoreError: if the device is not found """ @@ -75,3 +75,23 @@ class DeviceStore(SQLBaseStore): retcols=("user_id", "device_id", "display_name"), desc="get_device", ) + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """Retrieve all of a user's registered devices. + + Args: + user_id (str): + Returns: + defer.Deferred: resolves to a dict from device_id to a dict + containing "device_id", "user_id" and "display_name" for each + device. + """ + devices = yield self._simple_select_list( + table="devices", + keyvalues={"user_id": user_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user" + ) + + defer.returnValue({d["device_id"]: d for d in devices}) diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql new file mode 100644 index 0000000000..8a05677d42 --- /dev/null +++ b/synapse/storage/schema/delta/33/user_ips_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX user_ips_device_id ON user_ips(user_id, device_id, last_seen); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index cc6512ccc7..c2e12135d6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,25 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from synapse import types from twisted.internet import defer -from synapse.handlers.device import DeviceHandler -from tests import unittest -from tests.utils import setup_test_homeserver - - -class DeviceHandlers(object): - def __init__(self, hs): - self.device_handler = DeviceHandler(hs) +import synapse.handlers.device +import synapse.storage +from tests import unittest, utils class DeviceTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(DeviceTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + self.handler = None # type: device.DeviceHandler + self.clock = None # type: utils.MockClock + @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) - self.hs.handlers = handlers = DeviceHandlers(self.hs) - self.handler = handlers.device_handler + hs = yield utils.setup_test_homeserver(handlers=None) + self.handler = synapse.handlers.device.DeviceHandler(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): @@ -73,3 +75,55 @@ class DeviceTestCase(unittest.TestCase): dev = yield self.handler.store.get_device("theresa", device_id) self.assertEqual(dev["display_name"], "display") + + @defer.inlineCallbacks + def test_get_devices_by_user(self): + # check this works for both devices which have a recorded client_ip, + # and those which don't. + user1 = "@boris:aaa" + user2 = "@theresa:bbb" + yield self._record_user(user1, "xyz", "display 0") + yield self._record_user(user1, "fco", "display 1", "token1", "ip1") + yield self._record_user(user1, "abc", "display 2", "token2", "ip2") + yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + + yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + + res = yield self.handler.get_devices_by_user(user1) + self.assertEqual(3, len(res.keys())) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "xyz", + "display_name": "display 0", + "last_seen_ip": None, + "last_seen_ts": None, + }, res["xyz"]) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "fco", + "display_name": "display 1", + "last_seen_ip": "ip1", + "last_seen_ts": 1000000, + }, res["fco"]) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, res["abc"]) + + @defer.inlineCallbacks + def _record_user(self, user_id, device_id, display_name, + access_token=None, ip=None): + device_id = yield self.handler.check_device_registered( + user_id=user_id, + device_id=device_id, + initial_device_display_name=display_name + ) + + if ip is not None: + yield self.store.insert_client_ip( + types.UserID.from_string(user_id), + access_token, ip, "user_agent", device_id) + self.clock.advance_time(1000) \ No newline at end of file diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py new file mode 100644 index 0000000000..1f0c0e7c37 --- /dev/null +++ b/tests/storage/test_client_ips.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.server +import synapse.storage +import synapse.types +import tests.unittest +import tests.utils + + +class ClientIpStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + self.clock = None # type: tests.utils.MockClock + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def test_insert_new_client_ip(self): + self.clock.now = 12345678 + user_id = "@user:id" + yield self.store.insert_client_ip( + synapse.types.UserID.from_string(user_id), + "access_token", "ip", "user_agent", "device_id", + ) + + # deliberately use an iterable here to make sure that the lookup + # method doesn't iterate it twice + device_list = iter(((user_id, "device_id"),)) + result = yield self.store.get_last_client_ip_by_device(device_list) + + r = result[(user_id, "device_id")] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": "device_id", + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 12345678000, + }, + r + ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py new file mode 100644 index 0000000000..d3e9d97a9a --- /dev/null +++ b/tests/storage/test_devices.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.server +import synapse.types +import tests.unittest +import tests.utils + + +class DeviceStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(DeviceStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_store_new_device(self): + yield self.store.store_device( + "user_id", "device_id", "display_name" + ) + + res = yield self.store.get_device("user_id", "device_id") + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device_id", + "display_name": "display_name", + }, res) + + @defer.inlineCallbacks + def test_get_devices_by_user(self): + yield self.store.store_device( + "user_id", "device1", "display_name 1" + ) + yield self.store.store_device( + "user_id", "device2", "display_name 2" + ) + yield self.store.store_device( + "user_id2", "device3", "display_name 3" + ) + + res = yield self.store.get_devices_by_user("user_id") + self.assertEqual(2, len(res.keys())) + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device1", + "display_name": "display_name 1", + }, res["device1"]) + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device2", + "display_name": "display_name 2", + }, res["device2"]) -- cgit 1.5.1 From 40a1c96617fd5926b53f8993bb93af159af4d674 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 18:06:28 +0100 Subject: Fix PEP8 errors --- tests/handlers/test_device.py | 2 +- tests/storage/test_devices.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index c2e12135d6..b05aa9bb55 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -126,4 +126,4 @@ class DeviceTestCase(unittest.TestCase): yield self.store.insert_client_ip( types.UserID.from_string(user_id), access_token, ip, "user_agent", device_id) - self.clock.advance_time(1000) \ No newline at end of file + self.clock.advance_time(1000) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index d3e9d97a9a..a6ce993375 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -15,8 +15,6 @@ from twisted.internet import defer -import synapse.server -import synapse.types import tests.unittest import tests.utils -- cgit 1.5.1 From 406f7aa0f6ca7433e52433485824e80b79930498 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 17:58:44 +0100 Subject: Implement GET /device/{deviceId} --- synapse/handlers/device.py | 46 ++++++++++++++++++++++++++------- synapse/rest/client/v2_alpha/devices.py | 25 ++++++++++++++++++ tests/handlers/test_device.py | 37 +++++++++++++++++++------- 3 files changed, 89 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6bbbf59e52..3c88be0679 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import StoreError + +from synapse.api import errors from synapse.util import stringutils from twisted.internet import defer from ._base import BaseHandler @@ -65,10 +66,10 @@ class DeviceHandler(BaseHandler): ignore_if_known=False, ) defer.returnValue(device_id) - except StoreError: + except errors.StoreError: attempts += 1 - raise StoreError(500, "Couldn't generate a device ID.") + raise errors.StoreError(500, "Couldn't generate a device ID.") @defer.inlineCallbacks def get_devices_by_user(self, user_id): @@ -88,11 +89,38 @@ class DeviceHandler(BaseHandler): devices=((user_id, device_id) for device_id in devices.keys()) ) - for device_id in devices.keys(): - ip = ips.get((user_id, device_id), {}) - devices[device_id].update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + for device in devices.values(): + _update_device_from_client_ips(device, ips) defer.returnValue(devices) + + @defer.inlineCallbacks + def get_device(self, user_id, device_id): + """ Retrieve the given device + + Args: + user_id (str): + device_id (str) + + Returns: + defer.Deferred: dict[str, X]: info on the device + Raises: + errors.NotFoundError: if the device was not found + """ + try: + device = yield self.store.get_device(user_id, device_id) + except errors.StoreError, e: + raise errors.NotFoundError + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id),) + ) + _update_device_from_client_ips(device, ips) + defer.returnValue(device) + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 5cf8bd1afa..8b9ab4f674 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -47,5 +47,30 @@ class DevicesRestServlet(RestServlet): defer.returnValue((200, {"devices": devices})) +class DeviceRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", + releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + device = yield self.device_handler.get_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, device)) + + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b05aa9bb55..73f09874d8 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,6 +19,8 @@ import synapse.handlers.device import synapse.storage from tests import unittest, utils +user1 = "@boris:aaa" +user2 = "@theresa:bbb" class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -78,16 +80,7 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_devices_by_user(self): - # check this works for both devices which have a recorded client_ip, - # and those which don't. - user1 = "@boris:aaa" - user2 = "@theresa:bbb" - yield self._record_user(user1, "xyz", "display 0") - yield self._record_user(user1, "fco", "display 1", "token1", "ip1") - yield self._record_user(user1, "abc", "display 2", "token2", "ip2") - yield self._record_user(user1, "abc", "display 2", "token3", "ip3") - - yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + yield self._record_users() res = yield self.handler.get_devices_by_user(user1) self.assertEqual(3, len(res.keys())) @@ -113,6 +106,30 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res["abc"]) + @defer.inlineCallbacks + def test_get_device(self): + yield self._record_users() + + res = yield self.handler.get_device(user1, "abc") + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, res) + + @defer.inlineCallbacks + def _record_users(self): + # check this works for both devices which have a recorded client_ip, + # and those which don't. + yield self._record_user(user1, "xyz", "display 0") + yield self._record_user(user1, "fco", "display 1", "token1", "ip1") + yield self._record_user(user1, "abc", "display 2", "token2", "ip2") + yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + + yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + @defer.inlineCallbacks def _record_user(self, user_id, device_id, display_name, access_token=None, ip=None): -- cgit 1.5.1 From 1c3c202b969d6a7e5e4af2b2dca370f053b92c9f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 21 Jul 2016 13:15:15 +0100 Subject: Fix PEP8 errors --- synapse/handlers/device.py | 2 +- tests/handlers/test_device.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 3c88be0679..110f5fbb5c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -109,7 +109,7 @@ class DeviceHandler(BaseHandler): """ try: device = yield self.store.get_device(user_id, device_id) - except errors.StoreError, e: + except errors.StoreError: raise errors.NotFoundError ips = yield self.store.get_last_client_ip_by_device( devices=((user_id, device_id),) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 73f09874d8..87c3c75aea 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -22,6 +22,7 @@ from tests import unittest, utils user1 = "@boris:aaa" user2 = "@theresa:bbb" + class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) -- cgit 1.5.1 From 55abbe1850efff95efe9935873b666e5fc4bf0e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 21 Jul 2016 15:55:13 +0100 Subject: make /devices return a list Turns out I specced this to return a list of devices rather than a dict of them --- synapse/handlers/device.py | 10 +++++----- tests/handlers/test_device.py | 11 +++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 110f5fbb5c..1f9e15c33c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -79,17 +79,17 @@ class DeviceHandler(BaseHandler): Args: user_id (str): Returns: - defer.Deferred: dict[str, dict[str, X]]: map from device_id to - info on the device + defer.Deferred: list[dict[str, X]]: info on each device """ - devices = yield self.store.get_devices_by_user(user_id) + device_map = yield self.store.get_devices_by_user(user_id) ips = yield self.store.get_last_client_ip_by_device( - devices=((user_id, device_id) for device_id in devices.keys()) + devices=((user_id, device_id) for device_id in device_map.keys()) ) - for device in devices.values(): + devices = device_map.values() + for device in devices: _update_device_from_client_ips(device, ips) defer.returnValue(devices) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 87c3c75aea..331aa13fed 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -84,28 +84,31 @@ class DeviceTestCase(unittest.TestCase): yield self._record_users() res = yield self.handler.get_devices_by_user(user1) - self.assertEqual(3, len(res.keys())) + self.assertEqual(3, len(res)) + device_map = { + d["device_id"]: d for d in res + } self.assertDictContainsSubset({ "user_id": user1, "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, "last_seen_ts": None, - }, res["xyz"]) + }, device_map["xyz"]) self.assertDictContainsSubset({ "user_id": user1, "device_id": "fco", "display_name": "display 1", "last_seen_ip": "ip1", "last_seen_ts": 1000000, - }, res["fco"]) + }, device_map["fco"]) self.assertDictContainsSubset({ "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }, res["abc"]) + }, device_map["abc"]) @defer.inlineCallbacks def test_get_device(self): -- cgit 1.5.1 From 465117d7ca40ba9906697aa023897798f7833830 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 12:10:42 +0100 Subject: Fix background_update tests A bit of a cleanup for background_updates, and make sure that the real background updates have run before we start the unit tests, so that they don't interfere with the tests. --- synapse/storage/background_updates.py | 27 ++++++++++++++++++++------- tests/storage/test_background_update.py | 22 ++++++++++++++++------ 2 files changed, 36 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 75951d0173..2771f7c3c1 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -88,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore): @defer.inlineCallbacks def start_doing_background_updates(self): - while True: - if self._background_update_timer is not None: - return + assert(self._background_update_timer is not None, + "background updates already running") + + logger.info("Starting background schema updates") + while True: sleep = defer.Deferred() self._background_update_timer = self._clock.call_later( self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None @@ -102,7 +104,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_timer = None try: - result = yield self.do_background_update( + result = yield self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) except: @@ -113,11 +115,12 @@ class BackgroundUpdateStore(SQLBaseStore): "No more background updates to do." " Unscheduling background update task." ) - return + defer.returnValue() @defer.inlineCallbacks - def do_background_update(self, desired_duration_ms): - """Does some amount of work on a background update + def do_next_background_update(self, desired_duration_ms): + """Does some amount of work on the next queued background update + Args: desired_duration_ms(float): How long we want to spend updating. @@ -136,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue.append(update['update_name']) if not self._background_update_queue: + # no work left to do defer.returnValue(None) + # pop from the front, and add back to the back update_name = self._background_update_queue.pop(0) self._background_update_queue.append(update_name) + res = yield self._do_background_update(update_name, desired_duration_ms) + defer.returnValue(res) + + @defer.inlineCallbacks + def _do_background_update(self, update_name, desired_duration_ms): + logger.info("Starting update batch on background update '%s'", + update_name) + update_handler = self._background_update_handlers[update_name] performance = self._background_update_performance.get(update_name) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 6e4d9b1373..4944cb0d2e 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver() # type: synapse.server.HomeServer self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase): "test_update", self.update_handler ) + # run the real background updates, to get them out the way + # (perhaps we should run them as part of the test HS setup, since we + # run all of the other schema setup stuff there?) + while True: + res = yield self.store.do_next_background_update(1000) + if res is None: + break + @defer.inlineCallbacks def test_do_background_update(self): desired_count = 1000 duration_ms = 42 + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): self.clock.advance_time_msec(count * duration_ms) @@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): yield self.store.start_background_update("test_update", {"my_key": 1}) self.update_handler.reset_mock() - result = yield self.store.do_background_update( + result = yield self.store.do_next_background_update( duration_ms * desired_count ) self.assertIsNotNone(result) @@ -50,24 +59,25 @@ class BackgroundUpdateTestCase(unittest.TestCase): {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE ) + # second step: complete the update @defer.inlineCallbacks def update(progress, count): yield self.store._end_background_update("test_update") defer.returnValue(count) self.update_handler.side_effect = update - self.update_handler.reset_mock() - result = yield self.store.do_background_update( - duration_ms * desired_count + result = yield self.store.do_next_background_update( + duration_ms * desired_count ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( {"my_key": 2}, desired_count ) + # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_background_update( + result = yield self.store.do_next_background_update( duration_ms * desired_count ) self.assertIsNone(result) -- cgit 1.5.1 From f16f0e169d30e6920b892ee772693199b16713fd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 12:12:47 +0100 Subject: Slightly saner logging for unittests 1. Give the handler used for logging in unit tests a formatter, so that the output is slightly more meaningful 2. Log some synapse.storage stuff, because it's useful. --- tests/unittest.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index 5b22abfe74..38715972dd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -17,13 +17,18 @@ from twisted.trial import unittest import logging - # logging doesn't have a "don't log anything at all EVARRRR setting, # but since the highest value is 50, 1000000 should do ;) NEVER = 1000000 -logging.getLogger().addHandler(logging.StreamHandler()) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter( + "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]" +)) +logging.getLogger().addHandler(handler) logging.getLogger().setLevel(NEVER) +logging.getLogger("synapse.storage.SQL").setLevel(NEVER) +logging.getLogger("synapse.storage.txn").setLevel(NEVER) def around(target): @@ -70,8 +75,6 @@ class TestCase(unittest.TestCase): return ret logging.getLogger().setLevel(level) - # Don't set SQL logging - logging.getLogger("synapse.storage").setLevel(old_level) return orig() def assertObjectHasAttributes(self, attrs, obj): -- cgit 1.5.1 From 42f4feb2b709671bb2dbbabfe1aad7e951479652 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 12:25:06 +0100 Subject: PEP8 --- tests/storage/test_background_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 4944cb0d2e..1286b4ce2d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -68,7 +68,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update self.update_handler.reset_mock() result = yield self.store.do_next_background_update( - duration_ms * desired_count + duration_ms * desired_count ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( -- cgit 1.5.1 From 436bffd15fb8382a0d2dddd3c6f7a077ba751da2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 22 Jul 2016 14:52:53 +0100 Subject: Implement deleting devices --- synapse/handlers/auth.py | 22 ++++++++++++++++-- synapse/handlers/device.py | 27 +++++++++++++++++++++- synapse/rest/client/v1/login.py | 13 ++++++++--- synapse/rest/client/v2_alpha/devices.py | 14 +++++++++++ synapse/rest/client/v2_alpha/register.py | 10 ++++---- synapse/storage/devices.py | 15 ++++++++++++ synapse/storage/registration.py | 26 +++++++++++++++++---- .../schema/delta/33/access_tokens_device_index.sql | 17 ++++++++++++++ .../schema/delta/33/refreshtoken_device_index.sql | 17 ++++++++++++++ tests/handlers/test_device.py | 22 ++++++++++++++++-- tests/rest/client/v2_alpha/test_register.py | 14 +++++++---- 11 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 synapse/storage/schema/delta/33/access_tokens_device_index.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device_index.sql (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d2072436..2e138f328f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -77,6 +77,7 @@ class AuthHandler(BaseHandler): self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -374,7 +375,8 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None): + def get_login_tuple_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ Gets login tuple for the user with the given user ID. @@ -383,9 +385,15 @@ class AuthHandler(BaseHandler): The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. + The device will be recorded in the table if it is not there already. + Args: user_id (str): canonical User ID - device_id (str): the device ID to associate with the access token + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: A tuple of: The access token for the user's session. @@ -397,6 +405,16 @@ class AuthHandler(BaseHandler): logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1f9e15c33c..a7a192e1c9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler): Args: user_id (str): - device_id (str) + device_id (str): Returns: defer.Deferred: dict[str, X]: info on the device @@ -117,6 +117,31 @@ class DeviceHandler(BaseHandler): _update_device_from_client_ips(device, ips) defer.returnValue(device) + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens(user_id, + device_id=device_id) + + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e8b791519c..92fcae674a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id + registered_user_id, device_id, + login_submission.get("initial_device_display_name") ) ) result = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 8b9ab4f674..30ef8b3da9 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -70,6 +70,20 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, device)) + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c8c9395fc6..9f599ea8bb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token = yield self.auth_handler.issue_access_token( - user_id, device_id=device_id + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - refresh_token = yield self.auth_handler.issue_refresh_token( - user_id, device_id=device_id - ) defer.returnValue({ "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1cc6e07f2b..4689980f80 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -76,6 +76,21 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred + """ + return self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9a92b35361..935e82bf7a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,18 +18,31 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes - -from ._base import SQLBaseStore +from synapse.storage import background_updates from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def user_delete_access_tokens(self, user_id, except_token_ids=[], + device_id=None): def f(txn): sql = "SELECT token FROM access_tokens WHERE user_id = ?" clauses = [user_id] + if device_id is not None: + sql += " AND device_id = ?" + clauses.append(device_id) + if except_token_ids: sql += " AND id NOT IN (%s)" % ( ",".join(["?" for _ in except_token_ids]), diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql new file mode 100644 index 0000000000..bb225dafbf --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('refresh_tokens_device_index', '{}'); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 331aa13fed..214e722eb3 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse import types + from twisted.internet import defer +import synapse.api.errors import synapse.handlers.device + import synapse.storage +from synapse import types from tests import unittest, utils user1 = "@boris:aaa" @@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore - self.handler = None # type: device.DeviceHandler + self.handler = None # type: synapse.handlers.device.DeviceHandler self.clock = None # type: utils.MockClock @defer.inlineCallbacks @@ -123,6 +126,21 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res) + @defer.inlineCallbacks + def test_delete_device(self): + yield self._record_users() + + # delete the device + yield self.handler.delete_device(user1, "abc") + + # check the device was deleted + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.get_device(user1, "abc") + + # we'd like to check the access token was invalidated, but that's a + # bit of a PITA. + + @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3bd7065e32..8ac56a1fb2 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=user_id ) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname } self.assertDictContainsSubset(det_data, result) @@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) self.device_handler.check_device_registered = \ Mock(return_value=device_id) @@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname, "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) - self.auth_handler.issue_access_token.assert_called_once_with( - user_id, device_id=device_id) + self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, initial_device_display_name=None) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.5.1 From 012b4c19132d57fdbc1b6b0e304eb60eaf19200f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 17:51:24 +0100 Subject: Implement updating devices You can update the displayname of devices now. --- synapse/handlers/device.py | 24 ++++++++++++++++++++++ synapse/rest/client/v2_alpha/devices.py | 24 +++++++++++++++------- synapse/storage/devices.py | 27 ++++++++++++++++++++++++- tests/handlers/test_device.py | 16 +++++++++++++++ tests/storage/test_devices.py | 36 +++++++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a7a192e1c9..9e65d85e6d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler): yield self.store.user_delete_access_tokens(user_id, device_id=device_id) + @defer.inlineCallbacks + def update_device(self, user_id, device_id, content): + """ Update the given device + + Args: + user_id (str): + device_id (str): + content (dict): body of update request + + Returns: + defer.Deferred: + """ + + try: + yield self.store.update_device( + user_id, + device_id, + new_display_name=content.get("display_name") + ) + except errors.StoreError, e: + if e.code == 404: + raise errors.NotFoundError() + else: + raise def _update_device_from_client_ips(device, client_ips): diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 30ef8b3da9..8fbd3d3dfc 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -13,19 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging -from synapse.http.servlet import RestServlet +from twisted.internet import defer +from synapse.http import servlet from ._base import client_v2_patterns -import logging - - logger = logging.getLogger(__name__) -class DevicesRestServlet(RestServlet): +class DevicesRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) def __init__(self, hs): @@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet): defer.returnValue((200, {"devices": devices})) -class DeviceRestServlet(RestServlet): +class DeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", releases=[], v2_alpha=False) @@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, {})) + @defer.inlineCallbacks + def on_PUT(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + + body = servlet.parse_json_object_from_request(request) + yield self.device_handler.update_device( + requester.user.to_string(), + device_id, + body + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 4689980f80..afd6530cab 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore): Args: user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + device_id (str): The ID of the device to delete Returns: defer.Deferred """ @@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore): desc="delete_device", ) + def update_device(self, user_id, device_id, new_display_name=None): + """Update a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to update + new_display_name (str|None): new displayname for device; None + to leave unchanged + Raises: + StoreError: if the device is not found + Returns: + defer.Deferred + """ + updates = {} + if new_display_name is not None: + updates["display_name"] = new_display_name + if not updates: + return defer.succeed(None) + return self._simple_update_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues=updates, + desc="update_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 214e722eb3..85a970a6c9 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase): # we'd like to check the access token was invalidated, but that's a # bit of a PITA. + @defer.inlineCallbacks + def test_update_device(self): + yield self._record_users() + + update = {"display_name": "new display"} + yield self.handler.update_device(user1, "abc", update) + + res = yield self.handler.get_device(user1, "abc") + self.assertEqual(res["display_name"], "new display") + + @defer.inlineCallbacks + def test_update_unknown_device(self): + update = {"display_name": "new_display"} + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.update_device("user_id", "unknown_device_id", + update) @defer.inlineCallbacks def _record_users(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index a6ce993375..f8725acea0 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -15,6 +15,7 @@ from twisted.internet import defer +import synapse.api.errors import tests.unittest import tests.utils @@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase): "device_id": "device2", "display_name": "display_name 2", }, res["device2"]) + + @defer.inlineCallbacks + def test_update_device(self): + yield self.store.store_device( + "user_id", "device_id", "display_name 1" + ) + + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 1", res["display_name"]) + + # do a no-op first + yield self.store.update_device( + "user_id", "device_id", + ) + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 1", res["display_name"]) + + # do the update + yield self.store.update_device( + "user_id", "device_id", + new_display_name="display_name 2", + ) + + # check it worked + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 2", res["display_name"]) + + @defer.inlineCallbacks + def test_update_unknown_device(self): + with self.assertRaises(synapse.api.errors.StoreError) as cm: + yield self.store.update_device( + "user_id", "unknown_device_id", + new_display_name="display_name 2", + ) + self.assertEqual(404, cm.exception.code) -- cgit 1.5.1 From 8e0249416643f20f0c4cd8f2e19cf45ea63289d3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 26 Jul 2016 11:09:47 +0100 Subject: Delete refresh tokens when deleting devices --- synapse/handlers/device.py | 6 ++-- synapse/storage/registration.py | 58 +++++++++++++++++++++++++++++--------- tests/storage/test_registration.py | 34 ++++++++++++++++++++++ 3 files changed, 83 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9e65d85e6d..eaead50800 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -138,8 +138,10 @@ class DeviceHandler(BaseHandler): else: raise - yield self.store.user_delete_access_tokens(user_id, - device_id=device_id) + yield self.store.user_delete_access_tokens( + user_id, device_id=device_id, + delete_refresh_tokens=True, + ) @defer.inlineCallbacks def update_device(self, user_id, device_id, content): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 935e82bf7a..d9555e073a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -252,20 +252,36 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_ids=[], - device_id=None): - def f(txn): - sql = "SELECT token FROM access_tokens WHERE user_id = ?" + device_id=None, + delete_refresh_tokens=False): + """ + Invalidate access/refresh tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_ids (list[str]): list of access_tokens which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + delete_refresh_tokens (bool): True to delete refresh tokens as + well as access tokens. + Returns: + defer.Deferred: + """ + def f(txn, table, except_tokens, call_after_delete): + sql = "SELECT token FROM %s WHERE user_id = ?" % table clauses = [user_id] if device_id is not None: sql += " AND device_id = ?" clauses.append(device_id) - if except_token_ids: + if except_tokens: sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_token_ids]), + ",".join(["?" for _ in except_tokens]), ) - clauses += except_token_ids + clauses += except_tokens txn.execute(sql, clauses) @@ -274,16 +290,33 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): n = 100 chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] for chunk in chunks: - for row in chunk: - txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + if call_after_delete: + for row in chunk: + txn.call_after(call_after_delete, (row[0],)) txn.execute( - "DELETE FROM access_tokens WHERE token in (%s)" % ( + "DELETE FROM %s WHERE token in (%s)" % ( + table, ",".join(["?" for _ in chunk]), ), [r[0] for r in chunk] ) - yield self.runInteraction("user_delete_access_tokens", f) + # delete refresh tokens first, to stop new access tokens being + # allocated while our backs are turned + if delete_refresh_tokens: + yield self.runInteraction( + "user_delete_access_tokens", f, + table="refresh_tokens", + except_tokens=[], + call_after_delete=None, + ) + + yield self.runInteraction( + "user_delete_access_tokens", f, + table="access_tokens", + except_tokens=except_token_ids, + call_after_delete=self.get_user_by_access_token.invalidate, + ) def delete_access_token(self, access_token): def f(txn): @@ -306,9 +339,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and the ID of their access token. - Raises: - StoreError if no user was found. + defer.Deferred: None, if the token did not match, ootherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( "get_user_by_access_token", diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b03ca303a2..f7d74dea8e 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase): with self.assertRaises(StoreError): yield self.store.exchange_refresh_token(last_token, generator.generate) + @defer.inlineCallbacks + def test_user_delete_access_tokens(self): + # add some tokens + generator = TokenGenerator() + refresh_token = generator.generate(self.user_id) + yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) + yield self.store.add_refresh_token_to_user(self.user_id, refresh_token, + self.device_id) + + # now delete some + yield self.store.user_delete_access_tokens( + self.user_id, device_id=self.device_id, delete_refresh_tokens=True) + + # check they were deleted + user = yield self.store.get_user_by_access_token(self.tokens[1]) + self.assertIsNone(user, "access token was not deleted by device_id") + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(refresh_token, + generator.generate) + + # check the one not associated with the device was not deleted + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertEqual(self.user_id, user["name"]) + + # now delete the rest + yield self.store.user_delete_access_tokens( + self.user_id, delete_refresh_tokens=True) + + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertIsNone(user, + "access token was not deleted without device_id") + class TokenGenerator: def __init__(self): -- cgit 1.5.1 From eb359eced44407b1ee9648f10fdf3df63c8d40ad Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 26 Jul 2016 16:46:53 +0100 Subject: Add `create_requester` function Wrap the `Requester` constructor with a function which provides sensible defaults, and use it throughout --- synapse/api/auth.py | 24 +++++++++++------------- synapse/handlers/_base.py | 13 +++++++------ synapse/handlers/profile.py | 12 +++++++----- synapse/handlers/register.py | 16 +++++++++------- synapse/handlers/room_member.py | 20 +++++++++----------- synapse/rest/client/v2_alpha/keys.py | 10 ++++------ synapse/types.py | 33 ++++++++++++++++++++++++++++++++- tests/handlers/test_profile.py | 10 ++++++---- tests/replication/test_resource.py | 20 +++++++++++--------- tests/rest/client/v1/test_profile.py | 13 +++++-------- tests/utils.py | 5 ----- 11 files changed, 101 insertions(+), 75 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index eca8513905..eecf3b0b2a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import pymacaroons from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json, SignatureVerifyException - from twisted.internet import defer +from unpaddedbase64 import decode_base64 +import synapse.types from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import Requester, UserID, get_domain_from_id -from synapse.util.logutils import log_function +from synapse.types import UserID, get_domain_from_id from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logutils import log_function from synapse.util.metrics import Measure -from unpaddedbase64 import decode_base64 - -import logging -import pymacaroons logger = logging.getLogger(__name__) @@ -566,8 +566,7 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - defer.Deferred: resolves to a namedtuple including "user" (UserID) - "access_token_id" (int), "is_guest" (bool) + defer.Deferred: resolves to a ``synapse.types.Requester`` object Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -576,9 +575,7 @@ class Auth(object): user_id = yield self._get_appservice_user_id(request.args) if user_id: request.authenticated_entity = user_id - defer.returnValue( - Requester(UserID.from_string(user_id), "", False) - ) + defer.returnValue(synapse.types.create_requester(user_id)) access_token = request.args["access_token"][0] user_info = yield self.get_user_by_access_token(access_token, rights) @@ -612,7 +609,8 @@ class Auth(object): request.authenticated_entity = user.to_string() - defer.returnValue(Requester(user, token_id, is_guest)) + defer.returnValue(synapse.types.create_requester( + user, token_id, is_guest, device_id)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6264aa0d9a..11081a0cd5 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import LimitExceededError +import synapse.types from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, Requester - - -import logging +from synapse.api.errors import LimitExceededError +from synapse.types import UserID logger = logging.getLogger(__name__) @@ -124,7 +124,8 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = Requester(target_user, "", True) + requester = synapse.types.create_requester( + target_user, is_guest=True) handler = self.hs.get_handlers().room_member_handler yield handler.update_membership( requester, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 711a6a567f..d9ac09078d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer +import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID, Requester - +from synapse.types import UserID from ._base import BaseHandler -import logging - logger = logging.getLogger(__name__) @@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler): try: # Assume the user isn't a guest because we don't let guests set # profile or avatar data. - requester = Requester(user, "", False) + # XXX why are we recreating `requester` here for each room? + # what was wrong with the `requester` we were passed? + requester = synapse.types.create_requester(user) yield handler.update_membership( requester, user, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 94b19d0cb0..b9b5880d64 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,18 +14,19 @@ # limitations under the License. """Contains functions for registering clients.""" +import logging +import urllib + from twisted.internet import defer -from synapse.types import UserID, Requester +import synapse.types from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) -from ._base import BaseHandler -from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient - -import logging -import urllib +from synapse.types import UserID +from synapse.util.async import run_on_reactor +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler + requester = synapse.types.create_requester(user) yield profile_handler.set_displayname( - user, Requester(user, token, False), displayname + user, requester, displayname ) defer.returnValue((user_id, token)) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7e616f44fd..8cec8fc4ed 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -14,24 +14,22 @@ # limitations under the License. -from twisted.internet import defer +import logging -from ._base import BaseHandler +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from twisted.internet import defer +from unpaddedbase64 import decode_base64 -from synapse.types import UserID, RoomID, Requester +import synapse.types from synapse.api.constants import ( EventTypes, Membership, ) from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.types import UserID, RoomID from synapse.util.async import Linearizer from synapse.util.distributor import user_left_room, user_joined_room - -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes - -from unpaddedbase64 import decode_base64 - -import logging +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler): ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: - requester = Requester(target_user, None, False) + requester = synapse.types.create_requester(target_user) message_handler = self.hs.get_handlers().message_handler prev_event = message_handler.deduplicate_state_event(event, context) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 89ab39491c..56364af337 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -13,18 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import simplejson as json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID - -from canonicaljson import encode_canonical_json - from ._base import client_v2_patterns -import logging -import simplejson as json - logger = logging.getLogger(__name__) diff --git a/synapse/types.py b/synapse/types.py index f639651a73..5349b0c450 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError from collections import namedtuple -Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) +Requester = namedtuple("Requester", + ["user", "access_token_id", "is_guest", "device_id"]) +""" +Represents the user making a request + +Attributes: + user (UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time +""" + + +def create_requester(user_id, access_token_id=None, is_guest=False, + device_id=None): + """ + Create a new ``Requester`` object + + Args: + user_id (str|UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + return Requester(user_id, access_token_id, is_guest, device_id) def get_domain_from_id(string): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 4f2c14e4ff..f1f664275f 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,11 +19,12 @@ from twisted.internet import defer from mock import Mock, NonCallableMock +import synapse.types from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID -from tests.utils import setup_test_homeserver, requester_for_user +from tests.utils import setup_test_homeserver class ProfileHandlers(object): @@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase): def test_set_my_name(self): yield self.handler.set_displayname( self.frank, - requester_for_user(self.frank), + synapse.types.create_requester(self.frank), "Frank Jr." ) @@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase): def test_set_my_name_noauth(self): d = self.handler.set_displayname( self.frank, - requester_for_user(self.bob), + synapse.types.create_requester(self.bob), "Frank Jr." ) @@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): yield self.handler.set_avatar_url( - self.frank, requester_for_user(self.frank), "http://my.server/pic.gif" + self.frank, synapse.types.create_requester(self.frank), + "http://my.server/pic.gif" ) self.assertEquals( diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 842e3d29d7..e70ac6f14d 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.resource import ReplicationResource -from synapse.types import Requester, UserID +import contextlib +import json +from mock import Mock, NonCallableMock from twisted.internet import defer + +import synapse.types +from synapse.replication.resource import ReplicationResource +from synapse.types import UserID from tests import unittest -from tests.utils import setup_test_homeserver, requester_for_user -from mock import Mock, NonCallableMock -import json -import contextlib +from tests.utils import setup_test_homeserver class ReplicationResourceCase(unittest.TestCase): @@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase): def test_events_and_state(self): get = self.get(events="-1", state="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( - Requester(self.user, "", False), {} + synapse.types.create_requester(self.user), {} ) code, body = yield get self.assertEquals(code, 200) @@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase): def send_text_message(self, room_id, message): handler = self.hs.get_handlers().message_handler event = yield handler.create_and_send_nonmember_event( - requester_for_user(self.user), + synapse.types.create_requester(self.user), { "type": "m.room.message", "content": {"body": "message", "msgtype": "m.text"}, @@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase): @defer.inlineCallbacks def create_room(self): result = yield self.hs.get_handlers().room_creation_handler.create_room( - Requester(self.user, "", False), {} + synapse.types.create_requester(self.user), {} ) defer.returnValue(result["room_id"]) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index af02fce8fb..1e95e97538 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,17 +14,14 @@ # limitations under the License. """Tests REST events for /profile paths.""" -from tests import unittest -from twisted.internet import defer - from mock import Mock +from twisted.internet import defer -from ....utils import MockHttpResource, setup_test_homeserver - +import synapse.types from synapse.api.errors import SynapseError, AuthError -from synapse.types import Requester, UserID - from synapse.rest.client.v1 import profile +from tests import unittest +from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" PATH_PREFIX = "/_matrix/client/api/v1" @@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return Requester(UserID.from_string(myid), "", False) + return synapse.types.create_requester(myid) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/utils.py b/tests/utils.py index ed547bc39b..915b934e94 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer from synapse.federation.transport import server -from synapse.types import Requester from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.logcontext import LoggingContext @@ -512,7 +511,3 @@ class DeferredMockCallable(object): "call(%s)" % _format_call(c[0], c[1]) for c in calls ]) ) - - -def requester_for_user(user): - return Requester(user, None, False) -- cgit 1.5.1 From 0a7d3cd00f8b7e3ad0ba458c3ab9b40a2496545b Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 28 Jul 2016 20:24:24 +0100 Subject: Create separate methods for getting messages to push for the email and http pushers rather than trying to make a single method that will work with their conflicting requirements. The http pusher needs to get the messages in ascending stream order, and doesn't want to miss a message. The email pusher needs to get the messages in descending timestamp order, and doesn't mind if it misses messages. --- synapse/push/emailpusher.py | 5 +- synapse/push/httppusher.py | 3 +- synapse/replication/slave/storage/events.py | 7 +- synapse/storage/event_push_actions.py | 199 +++++++++++++++++++++------- tests/storage/test_event_push_actions.py | 41 ++++++ 5 files changed, 204 insertions(+), 51 deletions(-) create mode 100644 tests/storage/test_event_push_actions.py (limited to 'tests') diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 12a3ec7fd8..e224b68291 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -140,9 +140,8 @@ class EmailPusher(object): being run. """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering - unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( - self.user_id, start, self.max_stream_ordering - ) + fn = self.store.get_unread_push_actions_for_user_in_range_for_email + unprocessed = yield fn(self.user_id, start, self.max_stream_ordering) soonest_due_at = None diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 2acc6cc214..9a7db61220 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -141,7 +141,8 @@ class HttpPusher(object): run once per pusher. """ - unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( + fn = self.store.get_unread_push_actions_for_user_in_range_for_http + unprocessed = yield fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 369d839464..6a644f1386 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore): StreamStore.__dict__["get_recent_event_ids_for_room"] ) - get_unread_push_actions_for_user_in_range = ( - DataStore.get_unread_push_actions_for_user_in_range.__func__ + get_unread_push_actions_for_user_in_range_for_http = ( + DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ + ) + get_unread_push_actions_for_user_in_range_for_email = ( + DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ ) get_push_action_users_in_range = ( DataStore.get_push_action_users_in_range.__func__ diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 958dbcc22b..5ab362bef2 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -117,40 +117,149 @@ class EventPushActionsStore(SQLBaseStore): defer.returnValue(ret) @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range(self, user_id, - min_stream_ordering, - max_stream_ordering, - limit=20): + def get_unread_push_actions_for_user_in_range_for_http( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. + within the given stream ordering range. Called by the httppusher. Args: - user_id (str) - min_stream_ordering - max_stream_ordering - limit (int) + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. + Returns: + A promise which resolves to a list of dicts with the keys "event_id", + "room_id", "stream_ordering", "actions". + The list will be ordered by ascending stream_ordering. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the next + # push actions + def get_after_receipt(txn): + # find rooms that have a read receipt in them and return the next + # push actions + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" + " FROM (" + " SELECT room_id," + " MAX(topological_ordering) as topological_ordering," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " WHERE" + " ep.room_id = rl.room_id" + " AND (" + " ep.topological_ordering > rl.topological_ordering" + " OR (" + " ep.topological_ordering = rl.topological_ordering" + " AND ep.stream_ordering > rl.stream_ordering" + " )" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + after_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " e.received_ts" + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + no_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + ) + + notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": json.loads(row[3]), + } for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by stream_ordering, oldest first. + notifs.sort(key=lambda r: r['stream_ordering']) + + # Take only up to the limit. We have to stop at the limit because + # one of the subqueries may have hit the limit. + defer.returnValue(notifs[:limit]) + + @defer.inlineCallbacks + def get_unread_push_actions_for_user_in_range_for_email( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the emailpusher + + Args: + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. Returns: A promise which resolves to a list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". + The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ # find rooms that have a read receipt in them and return the most recent # push actions def get_after_receipt(txn): - # XXX: Do we really need to GROUP BY user_id on the inner SELECT? - # XXX: NATURAL JOIN obfuscates which columns are being joined on the - # inner SELECT (the room_id and event_id), can we - # INNER JOIN ... USING instead? sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, " - "e.received_ts " - "FROM (" - " SELECT room_id, user_id, " - " max(topological_ordering) as topological_ordering, " - " max(stream_ordering) as stream_ordering " - " FROM events" - " NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'" - " GROUP BY room_id, user_id" + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " e.received_ts" + " FROM (" + " SELECT room_id," + " MAX(topological_ordering) as topological_ordering," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" ") AS rl," " event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" @@ -165,47 +274,47 @@ class EventPushActionsStore(SQLBaseStore): " )" " AND ep.stream_ordering > ?" " AND ep.user_id = ?" - " AND ep.user_id = rl.user_id" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" ) - args = [min_stream_ordering, user_id] - if max_stream_ordering is not None: - sql += " AND ep.stream_ordering <= ?" - args.append(max_stream_ordering) - sql += " ORDER BY ep.stream_ordering DESC LIMIT ?" - args.append(limit) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] txn.execute(sql, args) return txn.fetchall() after_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range", get_after_receipt + "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. def get_no_receipt(txn): - # XXX: Does the inner SELECT really need to select from the events table? - # We're just extracting the room_id, so isn't receipts_linearized enough? sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " e.received_ts" " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.room_id not in (" - " SELECT room_id FROM events NATURAL JOIN receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AND ep.user_id = ? AND ep.stream_ordering > ?" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" ) - args = [user_id, user_id, min_stream_ordering] - if max_stream_ordering is not None: - sql += " AND ep.stream_ordering <= ?" - args.append(max_stream_ordering) - sql += " ORDER BY ep.stream_ordering DESC LIMIT ?" - args.append(limit) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] txn.execute(sql, args) return txn.fetchall() no_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range", get_no_receipt + "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) # Make a list of dicts from the two sets of results. diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py new file mode 100644 index 0000000000..e9044afa2e --- /dev/null +++ b/tests/storage/test_event_push_actions.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import tests.unittest +import tests.utils + +USER_ID = "@user:example.com" + + +class EventPushActionsStoreTestCase(tests.unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_get_unread_push_actions_for_user_in_range_for_http(self): + yield self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) + + @defer.inlineCallbacks + def test_get_unread_push_actions_for_user_in_range_for_email(self): + yield self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) -- cgit 1.5.1 From 91fa69e0299167dad4df9831b8d175a99564b266 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 11:48:32 +0100 Subject: keys/query: return all users which were asked for In the situation where all of a user's devices get deleted, we want to indicate this to a client, so we want to return an empty dictionary, rather than nothing at all. --- synapse/handlers/e2e_keys.py | 9 +++++--- tests/handlers/test_e2e_keys.py | 46 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 tests/handlers/test_e2e_keys.py (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 1312cdf5ab..950fc927b1 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,6 +99,7 @@ class E2eKeysHandler(object): """ local_query = [] + result_dict = {} for user_id, device_ids in query.items(): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", @@ -111,15 +112,17 @@ class E2eKeysHandler(object): for device_id in device_ids: local_query.append((user_id, device_id)) + # make sure that each queried user appears in the result dict + result_dict[user_id] = {} + results = yield self.store.get_e2e_device_keys(local_query) # un-jsonify the results - json_result = collections.defaultdict(dict) for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): - json_result[user_id][device_id] = json.loads(json_bytes) + result_dict[user_id][device_id] = json.loads(json_bytes) - defer.returnValue(json_result) + defer.returnValue(result_dict) @defer.inlineCallbacks def on_federation_query_client_keys(self, query_body): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py new file mode 100644 index 0000000000..878a54dc34 --- /dev/null +++ b/tests/handlers/test_e2e_keys.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +from twisted.internet import defer + +import synapse.api.errors +import synapse.handlers.e2e_keys + +import synapse.storage +from tests import unittest, utils + + +class E2eKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield utils.setup_test_homeserver( + handlers=None, + replication_layer=mock.Mock(), + ) + self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) + + @defer.inlineCallbacks + def test_query_local_devices_no_devices(self): + """If the user has no devices, we expect an empty list. + """ + local_user = "@boris:" + self.hs.hostname + res = yield self.handler.query_local_devices({local_user: None}) + self.assertDictEqual(res, {local_user: {}}) -- cgit 1.5.1 From 68264d7404214d30d32310991c89ea4a234c319a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 07:46:57 +0100 Subject: Include device name in /keys/query response Add an 'unsigned' section which includes the device display name. --- synapse/handlers/e2e_keys.py | 11 +++-- synapse/storage/end_to_end_keys.py | 60 ++++++++++++++++------- tests/storage/test_end_to_end_keys.py | 92 +++++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 20 deletions(-) create mode 100644 tests/storage/test_end_to_end_keys.py (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 950fc927b1..bb69089b91 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -117,10 +117,15 @@ class E2eKeysHandler(object): results = yield self.store.get_e2e_device_keys(local_query) - # un-jsonify the results + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - result_dict[user_id][device_id] = json.loads(json_bytes) + for device_id, device_info in device_keys.items(): + r = json.loads(device_info["key_json"]) + r["unsigned"] = { + "device_display_name": device_info["device_display_name"], + } + result_dict[user_id][device_id] = r defer.returnValue(result_dict) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 62b7790e91..5c8ed3e492 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import twisted.internet.defer @@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore): query_list(list): List of pairs of user_ids and device_ids. Returns: Dict mapping from user-id to dict mapping from device_id to - key json byte strings. + dict containing "key_json", "device_display_name". """ - def _get_e2e_device_keys(txn): - result = {} - for user_id, device_id in query_list: - user_result = result.setdefault(user_id, {}) - keyvalues = {"user_id": user_id} - if device_id: - keyvalues["device_id"] = device_id - rows = self._simple_select_list_txn( - txn, table="e2e_device_keys_json", - keyvalues=keyvalues, - retcols=["device_id", "key_json"] - ) - for row in rows: - user_result[row["device_id"]] = row["key_json"] - return result - return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + if not query_list: + return {} + + return self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + ) + + def _get_e2e_device_keys_txn(self, txn, query_list): + query_clauses = [] + query_params = [] + + for (user_id, device_id) in query_list: + query_clause = "k.user_id = ?" + query_params.append(user_id) + + if device_id: + query_clause += " AND k.device_id = ?" + query_params.append(device_id) + + query_clauses.append(query_clause) + + sql = ( + "SELECT k.user_id, k.device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM e2e_device_keys_json k" + " LEFT JOIN devices d ON d.user_id = k.user_id" + " AND d.device_id = k.device_id" + " WHERE %s" + ) % ( + " OR ".join("("+q+")" for q in query_clauses) + ) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + result = collections.defaultdict(dict) + for row in rows: + result[row["user_id"]][row["device_id"]] = row + + return result def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def _add_e2e_one_time_keys(txn): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py new file mode 100644 index 0000000000..0ebc6dafe8 --- /dev/null +++ b/tests/storage/test_end_to_end_keys.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.api.errors +import tests.unittest +import tests.utils + + +class EndToEndKeyStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_key_without_device_name(self): + now = 1470174257070 + json = '{ "key": "value" }' + + yield self.store.set_e2e_device_keys( + "user", "device", now, json) + + res = yield self.store.get_e2e_device_keys((("user", "device"),)) + self.assertIn("user", res) + self.assertIn("device", res["user"]) + dev = res["user"]["device"] + self.assertDictContainsSubset({ + "key_json": json, + "device_display_name": None, + }, dev) + + @defer.inlineCallbacks + def test_get_key_with_device_name(self): + now = 1470174257070 + json = '{ "key": "value" }' + + yield self.store.set_e2e_device_keys( + "user", "device", now, json) + yield self.store.store_device( + "user", "device", "display_name" + ) + + res = yield self.store.get_e2e_device_keys((("user", "device"),)) + self.assertIn("user", res) + self.assertIn("device", res["user"]) + dev = res["user"]["device"] + self.assertDictContainsSubset({ + "key_json": json, + "device_display_name": "display_name", + }, dev) + + + @defer.inlineCallbacks + def test_multiple_devices(self): + now = 1470174257070 + + yield self.store.set_e2e_device_keys( + "user1", "device1", now, 'json11') + yield self.store.set_e2e_device_keys( + "user1", "device2", now, 'json12') + yield self.store.set_e2e_device_keys( + "user2", "device1", now, 'json21') + yield self.store.set_e2e_device_keys( + "user2", "device2", now, 'json22') + + res = yield self.store.get_e2e_device_keys((("user1", "device1"), + ("user2", "device2"))) + self.assertIn("user1", res) + self.assertIn("device1", res["user1"]) + self.assertNotIn("device2", res["user1"]) + self.assertIn("user2", res) + self.assertNotIn("device1", res["user2"]) + self.assertIn("device2", res["user2"]) -- cgit 1.5.1 From 98385888b891f544c68b749746604b593f98d729 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 14:57:46 +0100 Subject: PEP8 --- synapse/storage/end_to_end_keys.py | 20 ++++++++++---------- tests/storage/test_end_to_end_keys.py | 2 -- 2 files changed, 10 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 5c8ed3e492..385d607056 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -63,16 +63,16 @@ class EndToEndKeyStore(SQLBaseStore): query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " - " d.display_name AS device_display_name, " - " k.key_json" - " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" - " WHERE %s" - ) % ( - " OR ".join("("+q+")" for q in query_clauses) - ) + "SELECT k.user_id, k.device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM e2e_device_keys_json k" + " LEFT JOIN devices d ON d.user_id = k.user_id" + " AND d.device_id = k.device_id" + " WHERE %s" + ) % ( + " OR ".join("(" + q + ")" for q in query_clauses) + ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 0ebc6dafe8..453bc61438 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -15,7 +15,6 @@ from twisted.internet import defer -import synapse.api.errors import tests.unittest import tests.utils @@ -68,7 +67,6 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): "device_display_name": "display_name", }, dev) - @defer.inlineCallbacks def test_multiple_devices(self): now = 1470174257070 -- cgit 1.5.1 From e97648c4e219d56fa7b46cdd763a588e84a23f02 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Aug 2016 16:08:32 +0100 Subject: Test summarization --- synapse/rest/media/v1/preview_url_resource.py | 106 ++++++++++---------- tests/test_preview.py | 139 ++++++++++++++++++++++++++ 2 files changed, 193 insertions(+), 52 deletions(-) create mode 100644 tests/test_preview.py (limited to 'tests') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index ebd07d696b..b2e2e7e624 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -347,58 +347,7 @@ class PreviewUrlResource(Resource): re.sub(r'\s+', '\n', el.text).strip() for el in cloned_tree.iter() if el.text ) - - # Try to get a summary of between 200 and 500 words, respecting - # first paragraph and then word boundaries. - # TODO: Respect sentences? - MIN_SIZE = 200 - MAX_SIZE = 500 - - description = '' - - # Keep adding paragraphs until we get to the MIN_SIZE. - for text_node in text_nodes: - if len(description) < MIN_SIZE: - description += text_node + '\n' - else: - break - - description = description.strip() - description = re.sub(r'[\t ]+', ' ', description) - description = re.sub(r'[\t \r\n]*[\r\n]+', '\n', description) - - # If the concatenation of paragraphs to get above MIN_SIZE - # took us over MAX_SIZE, then we need to truncate mid paragraph - if len(description) > MAX_SIZE: - new_desc = "" - - # This splits the paragraph into words, but keeping the - # (preceeding) whitespace intact so we can easily concat - # words back together. - for match in re.finditer("\s*\S+", description): - word = match.group() - - # Keep adding words while the total length is less than - # MAX_SIZE. - if len(word) + len(new_desc) < MAX_SIZE: - new_desc += word - else: - # At this point the next word *will* take us over - # MAX_SIZE, but we also want to ensure that its not - # a huge word. If it is add it anyway and we'll - # truncate later. - if len(new_desc) < MIN_SIZE: - new_desc += word - break - - # Double check that we're not over the limit - if len(new_desc) > MAX_SIZE: - new_desc = new_desc[:MAX_SIZE] - - # We always add an ellipsis because at the very least - # we chopped mid paragraph. - description = new_desc.strip() + "…" - og['og:description'] = description if description else None + og['og:description'] = _summarize_paragraphs(text_nodes) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG @@ -506,3 +455,56 @@ class PreviewUrlResource(Resource): content_type.startswith("application/xhtml") ): return True + + +def summarize_paragraphs(text_nodes, min_size=200, max_size=500): + # Try to get a summary of between 200 and 500 words, respecting + # first paragraph and then word boundaries. + # TODO: Respect sentences? + + description = '' + + # Keep adding paragraphs until we get to the MIN_SIZE. + for text_node in text_nodes: + if len(description) < min_size: + text_node = re.sub(r'[\t \r\n]+', ' ', text_node) + description += text_node + '\n\n' + else: + break + + description = description.strip() + description = re.sub(r'[\t ]+', ' ', description) + description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) + + # If the concatenation of paragraphs to get above MIN_SIZE + # took us over MAX_SIZE, then we need to truncate mid paragraph + if len(description) > max_size: + new_desc = "" + + # This splits the paragraph into words, but keeping the + # (preceeding) whitespace intact so we can easily concat + # words back together. + for match in re.finditer("\s*\S+", description): + word = match.group() + + # Keep adding words while the total length is less than + # MAX_SIZE. + if len(word) + len(new_desc) < max_size: + new_desc += word + else: + # At this point the next word *will* take us over + # MAX_SIZE, but we also want to ensure that its not + # a huge word. If it is add it anyway and we'll + # truncate later. + if len(new_desc) < min_size: + new_desc += word + break + + # Double check that we're not over the limit + if len(new_desc) > max_size: + new_desc = new_desc[:max_size] + + # We always add an ellipsis because at the very least + # we chopped mid paragraph. + description = new_desc.strip() + "…" + return description if description else None diff --git a/tests/test_preview.py b/tests/test_preview.py new file mode 100644 index 0000000000..2a801173a0 --- /dev/null +++ b/tests/test_preview.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import unittest + +from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs + + +class PreviewTestCase(unittest.TestCase): + + def test_long_summarize(self): + example_paras = [ + """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: + Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in + Troms county, Norway. The administrative centre of the municipality is + the city of Tromsø. Outside of Norway, Tromso and Tromsö are + alternative spellings of the city.Tromsø is considered the northernmost + city in the world with a population above 50,000. The most populous town + north of it is Alta, Norway, with a population of 14,272 (2013).""", + + """Tromsø lies in Northern Norway. The municipality has a population of + (2015) 72,066, but with an annual influx of students it has over 75,000 + most of the year. It is the largest urban area in Northern Norway and the + third largest north of the Arctic Circle (following Murmansk and Norilsk). + Most of Tromsø, including the city centre, is located on the island of + Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, + Tromsøya had a population of 36,088. Substantial parts of the urban area + are also situated on the mainland to the east, and on parts of Kvaløya—a + large island to the west. Tromsøya is connected to the mainland by the Tromsø + Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the + Sandnessund Bridge. Tromsø Airport connects the city to many destinations + in Europe. The city is warmer than most other places located on the same + latitude, due to the warming effect of the Gulf Stream.""", + + """The city centre of Tromsø contains the highest number of old wooden + houses in Northern Norway, the oldest house dating from 1789. The Arctic + Cathedral, a modern church from 1965, is probably the most famous landmark + in Tromsø. The city is a cultural centre for its region, with several + festivals taking place in the summer. Some of Norway's best-known + musicians, Torbjørn Brundtland and Svein Berge of the electronica duo + Röyksopp and Lene Marlin grew up and started their careers in Tromsø. + Noted electronic musician Geir Jenssen also hails from Tromsø.""", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEquals( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway. The administrative centre of the municipality is" + " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" + " alternative spellings of the city.Tromsø is considered the northernmost" + " city in the world with a population above 50,000. The most populous town" + " north of it is Alta, Norway, with a population of 14,272 (2013)." + ) + + desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) + + self.assertEquals( + desc, + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. It is the largest urban area in Northern Norway and the" + " third largest north of the Arctic Circle (following Murmansk and Norilsk)." + " Most of Tromsø, including the city centre, is located on the island of" + " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," + " Tromsøya had a population of 36,088. Substantial parts of the…" + ) + + def test_short_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + + "The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEquals( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + ) + + def test_small_then_large_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + " The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + self.assertEquals( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. The city centre of Tromsø contains the highest number" + " of old wooden houses in Northern Norway, the oldest house dating from" + " 1789. The Arctic Cathedral, a modern church…" + ) -- cgit 1.5.1 From 6fe6a6f0299c97086a552eda75570eaa66ff2598 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 16:34:07 +0100 Subject: Fix login with m.login.token login with token (as used by CAS auth) was broken by 067596d, such that it always returned a 401. --- synapse/api/auth.py | 45 +++++++++++++++++++++++++------------- synapse/handlers/auth.py | 17 ++++----------- synapse/server.pyi | 4 ++++ tests/handlers/test_auth.py | 53 +++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 87 insertions(+), 32 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 59db76debc..0db26fcfd7 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -675,27 +675,18 @@ class Auth(object): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - user_prefix = "user_id = " - user = None - user_id = None - guest = False - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - user_id = caveat.caveat_id[len(user_prefix):] - user = UserID.from_string(user_id) - elif caveat.caveat_id == "guest = true": - guest = True + user_id = self.get_user_id_from_macaroon(macaroon) + user = UserID.from_string(user_id) self.validate_macaroon( macaroon, rights, self.hs.config.expire_access_token, user_id=user_id, ) - if user is None: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN - ) + guest = False + for caveat in macaroon.caveats: + if caveat.caveat_id == "guest = true": + guest = True if guest: ret = { @@ -743,6 +734,29 @@ class Auth(object): errcode=Codes.UNKNOWN_TOKEN ) + def get_user_id_from_macaroon(self, macaroon): + """Retrieve the user_id given by the caveats on the macaroon. + + Does *not* validate the macaroon. + + Args: + macaroon (pymacaroons.Macaroon): The macaroon to validate + + Returns: + (str) user id + + Raises: + AuthError if there is no user_id caveat in the macaroon + """ + user_prefix = "user_id = " + for caveat in macaroon.caveats: + if caveat.caveat_id.startswith(user_prefix): + return caveat.caveat_id[len(user_prefix):] + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN + ) + def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): """ validate that a Macaroon is understood by and was signed by this server. @@ -754,6 +768,7 @@ class Auth(object): verify_expiry(bool): Whether to verify whether the macaroon has expired. This should really always be True, but no clients currently implement token refresh, so we can't enforce expiry yet. + user_id (str): The user_id required """ v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2e138f328f..1d3641b7a7 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -720,10 +720,11 @@ class AuthHandler(BaseHandler): def validate_short_term_login_token_and_get_user_id(self, login_token): try: - macaroon = pymacaroons.Macaroon.deserialize(login_token) auth_api = self.hs.get_auth() - auth_api.validate_macaroon(macaroon, "login", True) - return self.get_user_from_macaroon(macaroon) + macaroon = pymacaroons.Macaroon.deserialize(login_token) + user_id = auth_api.get_user_id_from_macaroon(macaroon) + auth_api.validate_macaroon(macaroon, "login", True, user_id) + return user_id except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) @@ -736,16 +737,6 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def get_user_from_macaroon(self, macaroon): - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] - raise AuthError( - self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", - errcode=Codes.UNKNOWN_TOKEN - ) - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) diff --git a/synapse/server.pyi b/synapse/server.pyi index c0aa868c4f..9570df5537 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,4 @@ +import synapse.api.auth import synapse.handlers import synapse.handlers.auth import synapse.handlers.device @@ -6,6 +7,9 @@ import synapse.storage import synapse.state class HomeServer(object): + def get_auth(self) -> synapse.api.auth.Auth: + pass + def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: pass diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 21077cbe9a..366cf97ae4 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -14,12 +14,13 @@ # limitations under the License. import pymacaroons +from twisted.internet import defer +import synapse +import synapse.api.errors from synapse.handlers.auth import AuthHandler from tests import unittest from tests.utils import setup_test_homeserver -from twisted.internet import defer - class AuthHandlers(object): def __init__(self, hs): @@ -31,11 +32,12 @@ class AuthTestCase(unittest.TestCase): def setUp(self): self.hs = yield setup_test_homeserver(handlers=None) self.hs.handlers = AuthHandlers(self.hs) + self.auth_handler = self.hs.handlers.auth_handler def test_token_is_a_macaroon(self): self.hs.config.macaroon_secret_key = "this key is a huge secret" - token = self.hs.handlers.auth_handler.generate_access_token("some_user") + token = self.auth_handler.generate_access_token("some_user") # Check that we can parse the thing with pymacaroons macaroon = pymacaroons.Macaroon.deserialize(token) # The most basic of sanity checks @@ -46,7 +48,7 @@ class AuthTestCase(unittest.TestCase): self.hs.config.macaroon_secret_key = "this key is a massive secret" self.hs.clock.now = 5000 - token = self.hs.handlers.auth_handler.generate_access_token("a_user") + token = self.auth_handler.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) def verify_gen(caveat): @@ -67,3 +69,46 @@ class AuthTestCase(unittest.TestCase): v.satisfy_general(verify_type) v.satisfy_general(verify_expiry) v.verify(macaroon, self.hs.config.macaroon_secret_key) + + def test_short_term_login_token_gives_user_id(self): + self.hs.clock.now = 1000 + + token = self.auth_handler.generate_short_term_login_token( + "a_user", 5000 + ) + + self.assertEqual( + "a_user", + self.auth_handler.validate_short_term_login_token_and_get_user_id( + token + ) + ) + + # when we advance the clock, the token should be rejected + self.hs.clock.now = 6000 + with self.assertRaises(synapse.api.errors.AuthError): + self.auth_handler.validate_short_term_login_token_and_get_user_id( + token + ) + + def test_short_term_login_token_cannot_replace_user_id(self): + token = self.auth_handler.generate_short_term_login_token( + "a_user", 5000 + ) + macaroon = pymacaroons.Macaroon.deserialize(token) + + self.assertEqual( + "a_user", + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) + ) + + # add another "user_id" caveat, which might allow us to override the + # user_id. + macaroon.add_first_party_caveat("user_id = b_user") + + with self.assertRaises(synapse.api.errors.AuthError): + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) -- cgit 1.5.1 From a8bcc7274d42e27c9d2966e908396cda91c379f4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:20:38 +0100 Subject: PEP8 --- tests/handlers/test_auth.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 366cf97ae4..4a8cd19acf 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -22,6 +22,7 @@ from synapse.handlers.auth import AuthHandler from tests import unittest from tests.utils import setup_test_homeserver + class AuthHandlers(object): def __init__(self, hs): self.auth_handler = AuthHandler(hs) -- cgit 1.5.1