From 0a32208e5dde4980a5962f17e9b27f2e28e1f3f1 Mon Sep 17 00:00:00 2001 From: Martin Weinelt Date: Mon, 6 Jun 2016 02:05:57 +0200 Subject: Rework ldap integration with ldap3 Use the pure-python ldap3 library, which eliminates the need for a system dependency. Offer both a `search` and `simple_bind` mode, for more sophisticated ldap scenarios. - `search` tries to find a matching DN within the `user_base` while employing the `user_filter`, then tries the bind when a single matching DN was found. - `simple_bind` tries the bind against a specific DN by combining the localpart and `user_base` Offer support for STARTTLS on a plain connection. The configuration was changed to reflect these new possibilities. Signed-off-by: Martin Weinelt --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/utils.py b/tests/utils.py index 6e41ae1ff6..ed547bc39b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} + config.ldap_enabled = False if "clock" not in kargs: kargs["clock"] = MockClock() -- cgit 1.4.1 From 2455ad8468ea3d372d0f3b3828efa10419ad68ad Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 24 Jun 2016 13:34:20 +0100 Subject: Remove room name & alias test as get_room_name_and_alias is now gone --- tests/replication/slave/storage/test_events.py | 41 -------------------------- 1 file changed, 41 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 17587fda00..f33e6f60fb 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -58,47 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def tearDown(self): [unpatch() for unpatch in self.unpatches] - @defer.inlineCallbacks - def test_room_name_and_aliases(self): - create = yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") - yield self.persist(type="m.room.name", key="", name="name1") - yield self.persist( - type="m.room.aliases", key="blue", aliases=["#1:blue"] - ) - yield self.replicate() - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) - ) - - # Set the room name. - yield self.persist(type="m.room.name", key="", name="name2") - yield self.replicate() - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) - ) - - # Set the room aliases. - yield self.persist( - type="m.room.aliases", key="blue", aliases=["#2:blue"] - ) - yield self.replicate() - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) - ) - - # Leave and join the room clobbering the state. - yield self.persist(type="m.room.member", key=USER_ID, membership="leave") - yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - reset_state=[create] - ) - yield self.replicate() - - yield self.check( - "get_room_name_and_aliases", (ROOM_ID,), (None, []) - ) - @defer.inlineCallbacks def test_room_members(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.4.1 From 7335f0addae9ff473403eaaffd7d2b02a9f1991f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 14:44:25 +0100 Subject: Add ReadWriteLock --- synapse/util/async.py | 82 +++++++++++++++++++++++++++++++++++++++++++++ tests/util/test_rwlock.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 tests/util/test_rwlock.py (limited to 'tests') diff --git a/synapse/util/async.py b/synapse/util/async.py index 40be7fe7e3..c84b23ff46 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -194,3 +194,85 @@ class Linearizer(object): self.key_to_defer.pop(key, None) defer.returnValue(_ctx_manager()) + + +class ReadWriteLock(object): + """A deferred style read write lock. + + Example: + + with (yield read_write_lock.read("test_key")): + # do some work + """ + + # IMPLEMENTATION NOTES + # + # We track the most recent queued reader and writer deferreds (which get + # resolved when they release the lock). + # + # Read: We know its safe to acquire a read lock when the latest writer has + # been resolved. The new reader is appeneded to the list of latest readers. + # + # Write: We know its safe to acquire the write lock when both the latest + # writers and readers have been resolved. The new writer replaces the latest + # writer. + + def __init__(self): + # Latest readers queued + self.key_to_current_readers = {} + + # Latest writer queued + self.key_to_current_writer = {} + + @defer.inlineCallbacks + def read(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.setdefault(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + curr_readers.add(new_defer) + + # We wait for the latest writer to finish writing. We can safely ignore + # any existing readers... as they're readers. + yield curr_writer + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + self.key_to_current_readers.get(key, set()).discard(new_defer) + + defer.returnValue(_ctx_manager()) + + @defer.inlineCallbacks + def write(self, key): + new_defer = defer.Deferred() + + curr_readers = self.key_to_current_readers.get(key, set()) + curr_writer = self.key_to_current_writer.get(key, None) + + # We wait on all latest readers and writer. + to_wait_on = list(curr_readers) + if curr_writer: + to_wait_on.append(curr_writer) + + # We can clear the list of current readers since the new writer waits + # for them to finish. + curr_readers.clear() + self.key_to_current_writer[key] = new_defer + + yield defer.gatherResults(to_wait_on) + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + new_defer.callback(None) + if self.key_to_current_writer[key] == new_defer: + self.key_to_current_writer.pop(key) + + defer.returnValue(_ctx_manager()) diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py new file mode 100644 index 0000000000..1d745ae1a7 --- /dev/null +++ b/tests/util/test_rwlock.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from synapse.util.async import ReadWriteLock + + +class ReadWriteLockTestCase(unittest.TestCase): + + def _assert_called_before_not_after(self, lst, first_false): + for i, d in enumerate(lst[:first_false]): + self.assertTrue(d.called, msg="%d was unexpectedly false" % i) + + for i, d in enumerate(lst[first_false:]): + self.assertFalse( + d.called, msg="%d was unexpectedly true" % (i + first_false) + ) + + def test_rwlock(self): + rwlock = ReadWriteLock() + + key = object() + + ds = [ + rwlock.read(key), # 0 + rwlock.read(key), # 1 + rwlock.write(key), # 2 + rwlock.write(key), # 3 + rwlock.read(key), # 4 + rwlock.read(key), # 5 + rwlock.write(key), # 6 + ] + + self._assert_called_before_not_after(ds, 2) + + with ds[0].result: + self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(ds, 2) + + with ds[1].result: + self._assert_called_before_not_after(ds, 2) + self._assert_called_before_not_after(ds, 3) + + with ds[2].result: + self._assert_called_before_not_after(ds, 3) + self._assert_called_before_not_after(ds, 4) + + with ds[3].result: + self._assert_called_before_not_after(ds, 4) + self._assert_called_before_not_after(ds, 6) + + with ds[5].result: + self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(ds, 6) + + with ds[4].result: + self._assert_called_before_not_after(ds, 6) + self._assert_called_before_not_after(ds, 7) + + with ds[6].result: + pass + + d = rwlock.write(key) + self.assertTrue(d.called) + with d.result: + pass + + d = rwlock.read(key) + self.assertTrue(d.called) + with d.result: + pass -- cgit 1.4.1 From 0136a522b18a734db69171d60566f501c0ced663 Mon Sep 17 00:00:00 2001 From: Negar Fazeli Date: Fri, 8 Jul 2016 16:53:18 +0200 Subject: Bug fix: expire invalid access tokens --- synapse/api/auth.py | 3 +++ synapse/handlers/auth.py | 5 +++-- synapse/handlers/register.py | 6 +++--- synapse/rest/client/v1/register.py | 2 +- tests/api/test_auth.py | 31 ++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 4 ++-- 6 files changed, 42 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a4d658a9d0..521a52e001 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -629,7 +629,10 @@ class Auth(object): except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. + if self.hs.config.expire_access_token: + raise ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e259213a36..5a0ed9d6b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -637,12 +637,13 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token(self, user_id, extra_caveats=None, + duration_in_ms=(60 * 60 * 1000)): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8c3381df8a..6b33b27149 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds, + def get_or_create_user(self, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_short_term_login_token( - user_id, duration_seconds) + token = self.auth_handler().generate_access_token( + user_id, None, duration_in_ms) if need_register: yield self.store.register( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ce7099b18f..8e1f1b7845 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet): user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds, + duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ad269af0ec..960c23d631 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) - macaroon.add_first_party_caveat("time < 1") # ms + macaroon.add_first_party_caveat("time < -2000") # ms self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True @@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase): yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_with_valid_duration(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("time < 900000000") # ms + + self.hs.clock.now = 5000 # seconds + self.hs.config.expire_access_token = True + + user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 69a5e5b1d4..a7de3c7c17 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase): http_client=None, expire_access_token=True) self.auth_handler = Mock( - generate_short_term_login_token=Mock(return_value='secret')) + generate_access_token=Mock(return_value='secret')) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ - "generate_short_term_login_token", + "generate_access_token", ]) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) -- cgit 1.4.1 From a98d2152049b0a61426ed3d8b6ac872a9ca3f535 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jul 2016 15:59:25 +0100 Subject: Add filter param to /messages API --- synapse/handlers/message.py | 16 ++++++++++++---- synapse/rest/client/v1/room.py | 11 ++++++++++- tests/storage/event_injector.py | 1 + tests/storage/test_events.py | 12 ++++++------ 4 files changed, 29 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad2753c1b5..dc76d34a52 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -66,7 +66,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, event_filter=None): """Get messages in a room. Args: @@ -75,11 +75,11 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None Returns: dict: Pagination API results """ user_id = requester.user.to_string() - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: room_token = pagin_config.from_token.room_key @@ -129,8 +129,13 @@ class MessageHandler(BaseHandler): room_id, max_topo ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, ) next_token = pagin_config.from_token.copy_and_replace( @@ -144,6 +149,9 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) + if event_filter: + events = event_filter.filter(events) + events = yield filter_events_for_client( self.store, user_id, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86fbe2747d..866a1e9120 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event from synapse.http.servlet import parse_json_object_from_request import logging import urllib +import ujson as json logger = logging.getLogger(__name__) @@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args + filter_bytes = request.args.get("filter", None) + if filter_bytes: + filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + event_filter=event_filter, ) defer.returnValue((200, msgs)) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py index f22ba8db89..38556da9a7 100644 --- a/tests/storage/event_injector.py +++ b/tests/storage/event_injector.py @@ -30,6 +30,7 @@ class EventInjector: def create_room(self, room): builder = self.event_builder_factory.new({ "type": EventTypes.Create, + "sender": "", "room_id": room.to_string(), "content": {}, }) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 18a6cff0c7..3762b38e37 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_count_daily_messages(self): - self.db_pool.runQuery("DELETE FROM stats_reporting") + yield self.db_pool.runQuery("DELETE FROM stats_reporting") self.hs.clock.now = 100 @@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase): # it isn't old enough. count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(1, self.hs.clock.now) + yield self._assert_stats_reporting(1, self.hs.clock.now) # Already reported yesterday, two new events from today. yield self.event_injector.inject_message(room, user, "Yeah they are!") @@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += 60 * 60 * 24 count = yield self.store.count_daily_messages() self.assertEqual(2, count) # 2 since yesterday - self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever # Last reported too recently. yield self.event_injector.inject_message(room, user, "Who could disagree?") self.hs.clock.now += 60 * 60 * 22 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(4, self.hs.clock.now) + yield self._assert_stats_reporting(4, self.hs.clock.now) # Last reported too long ago yield self.event_injector.inject_message(room, user, "No one.") self.hs.clock.now += 60 * 60 * 26 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(5, self.hs.clock.now) + yield self._assert_stats_reporting(5, self.hs.clock.now) # And now let's actually report something yield self.event_injector.inject_message(room, user, "Indeed.") @@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += (60 * 60 * 24) + 50 count = yield self.store.count_daily_messages() self.assertEqual(3, count) - self._assert_stats_reporting(8, self.hs.clock.now) + yield self._assert_stats_reporting(8, self.hs.clock.now) @defer.inlineCallbacks def _get_last_stream_token(self): -- cgit 1.4.1 From f863a52ceacf69ab19b073383be80603a2f51c0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 13:19:07 +0100 Subject: Add device_id support to /login Add a 'devices' table to the storage, as well as a 'device_id' column to refresh_tokens. Allow the client to pass a device_id, and initial_device_display_name, to /login. If login is successful, then register the device in the devices table if it wasn't known already. If no device_id was supplied, make one up. Associate the device_id with the access token and refresh token, so that we can get at it again later. Ensure that the device_id is copied from the refresh token to the access_token when the token is refreshed. --- synapse/handlers/auth.py | 19 +++--- synapse/handlers/device.py | 71 ++++++++++++++++++++ synapse/rest/client/v1/login.py | 39 ++++++++++- synapse/rest/client/v2_alpha/tokenrefresh.py | 10 ++- synapse/server.py | 5 ++ synapse/storage/__init__.py | 3 + synapse/storage/devices.py | 77 ++++++++++++++++++++++ synapse/storage/registration.py | 28 +++++--- synapse/storage/schema/delta/33/devices.sql | 21 ++++++ .../schema/delta/33/refreshtoken_device.sql | 16 +++++ tests/handlers/test_device.py | 75 +++++++++++++++++++++ tests/storage/test_registration.py | 21 ++++-- 12 files changed, 354 insertions(+), 31 deletions(-) create mode 100644 synapse/handlers/device.py create mode 100644 synapse/storage/devices.py create mode 100644 synapse/storage/schema/delta/33/devices.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device.sql create mode 100644 tests/handlers/test_device.py (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 983994fa95..ce9bc18849 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -361,7 +361,7 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id, device_id=None): """ Gets login tuple for the user with the given user ID. @@ -372,6 +372,7 @@ class AuthHandler(BaseHandler): Args: user_id (str): canonical User ID + device_id (str): the device ID to associate with the access token Returns: A tuple of: The access token for the user's session. @@ -380,9 +381,9 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + refresh_token = yield self.issue_refresh_token(user_id, device_id) defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks @@ -638,15 +639,17 @@ class AuthHandler(BaseHandler): defer.returnValue(False) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) @defer.inlineCallbacks - def issue_refresh_token(self, user_id): + def issue_refresh_token(self, user_id, device_id=None): refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) + yield self.store.add_refresh_token_to_user(user_id, refresh_token, + device_id) defer.returnValue(refresh_token) def generate_access_token(self, user_id, extra_caveats=None, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..8d7d9874f8 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import StoreError +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string_with_symbols(16) + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except StoreError: + attempts += 1 + + raise StoreError(500, "Couldn't generate a device ID.") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a1f2ba8773..e8b791519c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id=user_id, password=login_submission["password"], ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(registered_user_id) + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id + ) ) result = { "user_id": registered_user_id, @@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 8270e8787f..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet): try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_auth_handler() - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/server.py b/synapse/server.py index d49a1a8a96..e8b166990d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,6 +25,7 @@ from twisted.enterprise import adbapi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication +from synapse.handlers.device import DeviceHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -92,6 +93,7 @@ class HomeServer(object): 'typing_handler', 'room_list_handler', 'auth_handler', + 'device_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -197,6 +199,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_device_handler(self): + return DeviceHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1c93e18f9d..73fb334dd6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from twisted.internet import defer + +from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) @@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore, OpenIdStore, ClientIpStore, + DeviceStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py new file mode 100644 index 0000000000..9065e96d28 --- /dev/null +++ b/synapse/storage/devices.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class DeviceStore(SQLBaseStore): + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name, + ignore_if_known=True): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device + ignore_if_known (bool): ignore integrity errors which mean the + device is already known + Returns: + defer.Deferred + Raises: + StoreError: if ignore_if_known is False and the device was already + known + """ + try: + yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=ignore_if_known, + ) + except Exception as e: + logger.error("store_device with device_id=%s failed: %s", + device_id, e) + raise StoreError(500, "Problem storing device.") + + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a namedtuple containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d957a629dc..26ef1cfd8a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore): self.clock = hs.get_clock() @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): + def add_refresh_token_to_user(self, user_id, token, device_id=None): """Adds a refresh token for the given user. Args: user_id (str): The user ID. token (str): The new refresh token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_refresh_token_to_user", ) @@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore): ) def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. + """Exchange a refresh token for a new one. Doing so invalidates the old refresh token - refresh tokens are single use. Args: - token (str): The refresh token of a user. + refresh_token (str): The refresh token of a user. token_generator (fn: str -> str): Function which, when given a user ID, returns a unique refresh token for that user. This function must never return the same value twice. Returns: - tuple of (user_id, refresh_token) + tuple of (user_id, new_refresh_token, device_id) Raises: StoreError if no user was found with that refresh token. """ @@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore): ) def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" txn.execute(sql, (old_token,)) rows = self.cursor_to_dict(txn) if not rows: raise StoreError(403, "Did not recognize refresh token") user_id = rows[0]["user_id"] + device_id = rows[0]["device_id"] # TODO(danielwh): Maybe perform a validation on the macaroon that # macaroon.user_id == user_id. @@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore): sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" txn.execute(sql, (new_token, old_token,)) - return user_id, new_token + return user_id, new_token, device_id @defer.inlineCallbacks def is_server_admin(self, user): @@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql new file mode 100644 index 0000000000..b21da00dde --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE refresh_tokens ADD COLUMN device_id BIGINT; diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py new file mode 100644 index 0000000000..cc6512ccc7 --- /dev/null +++ b/tests/handlers/test_device.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.handlers.device import DeviceHandler +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DeviceHandlers(object): + def __init__(self, hs): + self.device_handler = DeviceHandler(hs) + + +class DeviceTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = handlers = DeviceHandlers(self.hs) + self.handler = handlers.device_handler + + @defer.inlineCallbacks + def test_device_is_created_if_doesnt_exist(self): + res = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_is_preserved_if_exists(self): + res1 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res1, "fco") + + res2 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="new display name" + ) + self.assertEqual(res2, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_id_is_made_up_if_unspecified(self): + device_id = yield self.handler.check_device_registered( + user_id="theresa", + device_id=None, + initial_device_display_name="display" + ) + + dev = yield self.handler.store.get_device("theresa", device_id) + self.assertEqual(dev["display_name"], "display") diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b8384c98d8..b03ca303a2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "BcDeFgHiJkLmNoPqRsTuVwXyZa" ] self.pwhash = "{xx1}123456789" + self.device_id = "akgjhdjklgshg" @defer.inlineCallbacks def test_register(self): @@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { "name": self.user_id, + "device_id": self.device_id, }, result ) @@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_exchange_refresh_token_valid(self): uid = stringutils.random_string(32) + device_id = stringutils.random_string(16) generator = TokenGenerator() last_token = generator.generate(uid) self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, last_token,)) + "INSERT INTO refresh_tokens(user_id, token, device_id) " + "VALUES(?,?,?)", + (uid, last_token, device_id)) - (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( - last_token, generator.generate) + (found_user_id, refresh_token, device_id) = \ + yield self.store.exchange_refresh_token(last_token, + generator.generate) self.assertEqual(uid, found_user_id) rows = yield self.db_pool.runQuery( - "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) - self.assertEqual([(refresh_token,)], rows) + "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", + (uid, )) + self.assertEqual([(refresh_token, device_id)], rows) # We issued token 1, then exchanged it for token 2 expected_refresh_token = u"%s-%d" % (uid, 2,) self.assertEqual(expected_refresh_token, refresh_token) -- cgit 1.4.1 From 0da0d0a29d807c481152b1580acbbe36f24cf771 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 13:12:22 +0100 Subject: rest/client/v2_alpha/register.py: Refactor flow somewhat. This is meant to be an *almost* non-functional change, with the exception that it fixes what looks a lot like a bug in that it only calls `auth_handler.add_threepid` and `add_pusher` once instead of three times. The idea is to move the generation of the `access_token` out of `registration_handler.register`, because `access_token`s now require a device_id, and we only want to generate a device_id once registration has been successful. --- synapse/rest/client/v2_alpha/register.py | 177 ++++++++++++++++------------ tests/rest/client/v2_alpha/test_register.py | 3 +- 2 files changed, 104 insertions(+), 76 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index e8d34b06b0..707bde0f34 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -199,92 +199,55 @@ class RegisterRestServlet(RestServlet): "Already registered user ID %r for this session", registered_user_id ) - access_token = yield self.auth_handler.issue_access_token(registered_user_id) - refresh_token = yield self.auth_handler.issue_refresh_token( - registered_user_id + # don't re-register the email address + add_email = False + else: + # NB: This may be from the auth handler and NOT from the POST + if 'password' not in params: + raise SynapseError(400, "Missing password.", + Codes.MISSING_PARAM) + + desired_username = params.get("username", None) + new_password = params.get("password", None) + guest_access_token = params.get("guest_access_token", None) + + (registered_user_id, _) = yield self.registration_handler.register( + localpart=desired_username, + password=new_password, + guest_access_token=guest_access_token, + generate_token=False, ) - defer.returnValue((200, { - "user_id": registered_user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - })) - # NB: This may be from the auth handler and NOT from the POST - if 'password' not in params: - raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) + # remember that we've now registered that user account, and with + # what user ID (since the user may not have specified) + self.auth_handler.set_session_data( + session_id, "registered_user_id", registered_user_id + ) - desired_username = params.get("username", None) - new_password = params.get("password", None) - guest_access_token = params.get("guest_access_token", None) + add_email = True - (user_id, token) = yield self.registration_handler.register( - localpart=desired_username, - password=new_password, - guest_access_token=guest_access_token, + access_token = yield self.auth_handler.issue_access_token( + registered_user_id ) - # remember that we've now registered that user account, and with what - # user ID (since the user may not have specified) - self.auth_handler.set_session_data( - session_id, "registered_user_id", user_id - ) - - if result and LoginType.EMAIL_IDENTITY in result: + if add_email and result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] - - for reqd in ['medium', 'address', 'validated_at']: - if reqd not in threepid: - logger.info("Can't add incomplete 3pid") - else: - yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], - ) - - # And we add an email pusher for them by default, but only - # if email notifications are enabled (so people don't start - # getting mail spam where they weren't before if email - # notifs are set up on a home server) - if ( - self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - ): - # Pull the ID of the access token back out of the db - # It would really make more sense for this to be passed - # up when the access token is saved, but that's quite an - # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) - - yield self.hs.get_pusherpool().add_pusher( - user_id=user_id, - access_token=user_tuple["token_id"], - kind="email", - app_id="m.email", - app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], - lang=None, # We don't know a user's language here - data={}, - ) - - if 'bind_email' in params and params['bind_email']: + reqd = ('medium', 'address', 'validated_at') + if all(x in threepid for x in reqd): + yield self._register_email_threepid( + registered_user_id, threepid, access_token + ) + # XXX why is bind_email not protected by this? + else: + logger.info("Can't add incomplete 3pid") + if params.get("bind_email"): logger.info("bind_email specified: binding") - - emailThreepid = result[LoginType.EMAIL_IDENTITY] - threepid_creds = emailThreepid['threepid_creds'] - logger.debug("Binding emails %s to %s" % ( - emailThreepid, user_id - )) - yield self.identity_handler.bind_threepid(threepid_creds, user_id) + yield self._bind_email(registered_user_id, threepid) else: logger.info("bind_email not specified: not binding email") - result = yield self._create_registration_details(user_id, token) + result = yield self._create_registration_details(registered_user_id, + access_token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -324,6 +287,70 @@ class RegisterRestServlet(RestServlet): ) defer.returnValue((yield self._create_registration_details(user_id, token))) + @defer.inlineCallbacks + def _register_email_threepid(self, user_id, threepid, token): + """Add an email address as a 3pid identifier + + Also adds an email pusher for the email address, if configured in the + HS config + + Args: + user_id (str): id of user + threepid (object): m.login.email.identity auth response + token (str): access_token for the user + Returns: + defer.Deferred: + """ + yield self.auth_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) + + # And we add an email pusher for them by default, but only + # if email notifications are enabled (so people don't start + # getting mail spam where they weren't before if email + # notifs are set up on a home server) + if (self.hs.config.email_enable_notifs and + self.hs.config.email_notif_for_new_users): + # Pull the ID of the access token back out of the db + # It would really make more sense for this to be passed + # up when the access token is saved, but that's quite an + # invasive change I'd rather do separately. + user_tuple = yield self.store.get_user_by_access_token( + token + ) + token_id = user_tuple["token_id"] + + yield self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) + defer.returnValue() + + def _bind_email(self, user_id, email_threepid): + """Bind emails to the given user_id on the identity server + + Args: + user_id (str): user id to bind the emails to + email_threepid (object): m.login.email.identity auth response + Returns: + defer.Deferred: + """ + threepid_creds = email_threepid['threepid_creds'] + logger.debug("Binding emails %s to %s" % ( + email_threepid, user_id + )) + return self.identity_handler.bind_threepid(threepid_creds, user_id) + @defer.inlineCallbacks def _create_registration_details(self, user_id, token): refresh_token = yield self.auth_handler.issue_refresh_token(user_id) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index cda0a2b27c..9a4215fef7 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -114,7 +114,8 @@ class RegisterRestServletTestCase(unittest.TestCase): "username": "kermit", "password": "monkey" }, None) - self.registration_handler.register = Mock(return_value=(user_id, token)) + self.registration_handler.register = Mock(return_value=(user_id, None)) + self.auth_handler.issue_access_token = Mock(return_value=token) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) -- cgit 1.4.1 From 40cbffb2d2ca0166f1377ac4ec5988046ea4ca10 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 18:46:19 +0100 Subject: Further registration refactoring * `RegistrationHandler.appservice_register` no longer issues an access token: instead it is left for the caller to do it. (There are two of these, one in `synapse/rest/client/v1/register.py`, which now simply calls `AuthHandler.issue_access_token`, and the other in `synapse/rest/client/v2_alpha/register.py`, which is covered below). * In `synapse/rest/client/v2_alpha/register.py`, move the generation of access_tokens into `_create_registration_details`. This means that the normal flow no longer needs to call `AuthHandler.issue_access_token`; the shared-secret flow can tell `RegistrationHandler.register` not to generate a token; and the appservice flow continues to work despite the above change. --- synapse/handlers/register.py | 13 +++++--- synapse/rest/client/v1/register.py | 4 ++- synapse/rest/client/v2_alpha/register.py | 50 +++++++++++++++++++++-------- synapse/storage/registration.py | 6 ++-- tests/rest/client/v2_alpha/test_register.py | 6 +++- 5 files changed, 57 insertions(+), 22 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6b33b27149..94b19d0cb0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -99,8 +99,13 @@ class RegistrationHandler(BaseHandler): localpart : The local part of the user ID to register. If None, one will be generated. password (str) : The password to assign to this user so they can - login again. This can be None which means they cannot login again - via a password (e.g. the user is an application service user). + login again. This can be None which means they cannot login again + via a password (e.g. the user is an application service user). + generate_token (bool): Whether a new access token should be + generated. Having this be True should be considered deprecated, + since it offers no means of associating a device_id with the + access_token. Instead you should call auth_handler.issue_access_token + after registration. Returns: A tuple of (user_id, access_token). Raises: @@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, - token=token, password_hash="", appservice_id=service_id, create_profile_with_localpart=user.localpart, ) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8e1f1b7845..28b59952c3 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -60,6 +60,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth_handler = hs.get_auth_handler() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -299,9 +300,10 @@ class RegisterRestServlet(ClientV1RestServlet): user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler - (user_id, token) = yield handler.appservice_register( + user_id = yield handler.appservice_register( user_localpart, as_token ) + token = yield self.auth_handler.issue_access_token(user_id) self._remove_session(session) defer.returnValue({ "user_id": user_id, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5db953a1e3..04004cfbbd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -226,19 +226,17 @@ class RegisterRestServlet(RestServlet): add_email = True - access_token = yield self.auth_handler.issue_access_token( + result = yield self._create_registration_details( registered_user_id ) if add_email and result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] yield self._register_email_threepid( - registered_user_id, threepid, access_token, + registered_user_id, threepid, result["access_token"], params.get("bind_email") ) - result = yield self._create_registration_details(registered_user_id, - access_token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -246,10 +244,10 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token): - (user_id, token) = yield self.registration_handler.appservice_register( + user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + defer.returnValue((yield self._create_registration_details(user_id))) @defer.inlineCallbacks def _do_shared_secret_registration(self, username, password, mac): @@ -273,10 +271,12 @@ class RegisterRestServlet(RestServlet): 403, "HMAC incorrect", ) - (user_id, token) = yield self.registration_handler.register( - localpart=username, password=password + (user_id, _) = yield self.registration_handler.register( + localpart=username, password=password, generate_token=False, ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + + result = yield self._create_registration_details(user_id) + defer.returnValue(result) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token, bind_email): @@ -349,11 +349,31 @@ class RegisterRestServlet(RestServlet): defer.returnValue() @defer.inlineCallbacks - def _create_registration_details(self, user_id, token): - refresh_token = yield self.auth_handler.issue_refresh_token(user_id) + def _create_registration_details(self, user_id): + """Complete registration of newly-registered user + + Issues access_token and refresh_token, and builds the success response + body. + + Args: + (str) user_id: full canonical @user:id + + + Returns: + defer.Deferred: (object) dictionary for response from /register + """ + + access_token = yield self.auth_handler.issue_access_token( + user_id + ) + + refresh_token = yield self.auth_handler.issue_refresh_token( + user_id + ) + defer.returnValue({ "user_id": user_id, - "access_token": token, + "access_token": access_token, "home_server": self.hs.hostname, "refresh_token": refresh_token, }) @@ -366,7 +386,11 @@ class RegisterRestServlet(RestServlet): generate_token=False, make_guest=True ) - access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + access_token = self.auth_handler.generate_access_token( + user_id, ["guest = true"] + ) + # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok + # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26ef1cfd8a..9a92b35361 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks - def register(self, user_id, token, password_hash, + def register(self, user_id, token=None, password_hash=None, was_guest=False, make_guest=False, appservice_id=None, create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. + token (str): The desired access token to use for this user. If this + is not None, the given access token is associated with the user + id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9a4215fef7..ccbb8776d3 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase): "id": "1234" } self.registration_handler.appservice_register = Mock( - return_value=(user_id, token) + return_value=user_id ) + self.auth_handler.issue_access_token = Mock(return_value=token) + (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { @@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase): } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) + self.auth_handler.issue_access_token.assert_called_once_with( + user_id) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.4.1 From b97a1356b149f62e5b2c28b09818d74b445cc635 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 18:38:26 +0100 Subject: Register a device_id in the /v2/register flow. This doesn't cover *all* of the registration flows, but it does cover the most common ones: in particular: shared_secret registration, appservice registration, and normal user/pass registration. Pull device_id from the registration parameters. Register the device in the devices table. Associate the device with the returned access and refresh tokens. Profit. --- synapse/rest/client/v2_alpha/register.py | 54 +++++++++++++++++++++-------- tests/rest/client/v2_alpha/test_register.py | 13 +++++-- 2 files changed, 49 insertions(+), 18 deletions(-) (limited to 'tests') diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b7e03ea9d1..d401722224 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -93,6 +93,7 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -145,7 +146,7 @@ class RegisterRestServlet(RestServlet): if isinstance(desired_username, basestring): result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0] + desired_username, request.args["access_token"][0], body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -155,7 +156,7 @@ class RegisterRestServlet(RestServlet): # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( - desired_username, desired_password, body["mac"] + desired_username, desired_password, body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -236,7 +237,7 @@ class RegisterRestServlet(RestServlet): add_email = True result = yield self._create_registration_details( - registered_user_id + registered_user_id, body ) if add_email and result and LoginType.EMAIL_IDENTITY in result: @@ -252,14 +253,14 @@ class RegisterRestServlet(RestServlet): return 200, {} @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token): + def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id))) + defer.returnValue((yield self._create_registration_details(user_id, body))) @defer.inlineCallbacks - def _do_shared_secret_registration(self, username, password, mac): + def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -267,7 +268,7 @@ class RegisterRestServlet(RestServlet): # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface - got_mac = str(mac) + got_mac = str(body["mac"]) want_mac = hmac.new( key=self.hs.config.registration_shared_secret, @@ -284,7 +285,7 @@ class RegisterRestServlet(RestServlet): localpart=username, password=password, generate_token=False, ) - result = yield self._create_registration_details(user_id) + result = yield self._create_registration_details(user_id, body) defer.returnValue(result) @defer.inlineCallbacks @@ -358,35 +359,58 @@ class RegisterRestServlet(RestServlet): defer.returnValue() @defer.inlineCallbacks - def _create_registration_details(self, user_id): + def _create_registration_details(self, user_id, body): """Complete registration of newly-registered user - Issues access_token and refresh_token, and builds the success response - body. + Allocates device_id if one was not given; also creates access_token + and refresh_token. Args: (str) user_id: full canonical @user:id - + (object) body: dictionary supplied to /register call, from + which we pull device_id and initial_device_name Returns: defer.Deferred: (object) dictionary for response from /register """ + device_id = yield self._register_device(user_id, body) access_token = yield self.auth_handler.issue_access_token( - user_id + user_id, device_id=device_id ) refresh_token = yield self.auth_handler.issue_refresh_token( - user_id + user_id, device_id=device_id ) - defer.returnValue({ "user_id": user_id, "access_token": access_token, "home_server": self.hs.hostname, "refresh_token": refresh_token, + "device_id": device_id, }) + def _register_device(self, user_id, body): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) body: dictionary supplied to /register call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + # register the user's device + device_id = body.get("device_id") + initial_display_name = body.get("initial_device_display_name") + device_id = self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + return device_id + @defer.inlineCallbacks def _do_guest_registration(self): if not self.hs.config.allow_guest_access: diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index ccbb8776d3..3bd7065e32 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -30,6 +30,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler = Mock() self.identity_handler = Mock() self.login_handler = Mock() + self.device_handler = Mock() # do the dance to hook it up to the hs global self.handlers = Mock( @@ -42,6 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) + self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.config.enable_registration = True # init the thing we're testing @@ -107,9 +109,11 @@ class RegisterRestServletTestCase(unittest.TestCase): def test_POST_user_valid(self): user_id = "@kermit:muppet" token = "kermits_access_token" + device_id = "frogfone" self.request_data = json.dumps({ "username": "kermit", - "password": "monkey" + "password": "monkey", + "device_id": device_id, }) self.registration_handler.check_username = Mock(return_value=True) self.auth_result = (True, None, { @@ -118,18 +122,21 @@ class RegisterRestServletTestCase(unittest.TestCase): }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) self.auth_handler.issue_access_token = Mock(return_value=token) + self.device_handler.check_device_registered = \ + Mock(return_value=device_id) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, - "home_server": self.hs.hostname + "home_server": self.hs.hostname, + "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) self.auth_handler.issue_access_token.assert_called_once_with( - user_id) + user_id, device_id=device_id) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.4.1 From ec041b335ecb20008609c8603338ab8c586615be Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 15:25:40 +0100 Subject: Record device_id in client_ips Record the device_id when we add a client ip; it's somewhat redundant as we could get it via the access_token, but it will make querying rather easier. --- synapse/api/auth.py | 29 +++++++++++++++++++++++------ synapse/storage/client_ips.py | 3 ++- tests/api/test_auth.py | 10 +++++++++- 3 files changed, 34 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ff7d816cfc..eca8513905 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -586,6 +586,10 @@ class Auth(object): token_id = user_info["token_id"] is_guest = user_info["is_guest"] + # device_id may not be present if get_user_by_access_token has been + # stubbed out. + device_id = user_info.get("device_id") + ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( "User-Agent", @@ -597,7 +601,8 @@ class Auth(object): user=user, access_token=access_token, ip=ip_addr, - user_agent=user_agent + user_agent=user_agent, + device_id=device_id, ) if is_guest and not allow_guest: @@ -695,6 +700,7 @@ class Auth(object): "user": user, "is_guest": True, "token_id": None, + "device_id": None, } elif rights == "delete_pusher": # We don't store these tokens in the database @@ -702,13 +708,20 @@ class Auth(object): "user": user, "is_guest": False, "token_id": None, + "device_id": None, } else: - # This codepath exists so that we can actually return a - # token ID, because we use token IDs in place of device - # identifiers throughout the codebase. - # TODO(daniel): Remove this fallback when device IDs are - # properly implemented. + # This codepath exists for several reasons: + # * so that we can actually return a token ID, which is used + # in some parts of the schema (where we probably ought to + # use device IDs instead) + # * the only way we currently have to invalidate an + # access_token is by removing it from the database, so we + # have to check here that it is still in the db + # * some attributes (notably device_id) aren't stored in the + # macaroon. They probably should be. + # TODO: build the dictionary from the macaroon once the + # above are fixed ret = yield self._look_up_user_by_access_token(macaroon_str) if ret["user"] != user: logger.error( @@ -782,10 +795,14 @@ class Auth(object): self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", errcode=Codes.UNKNOWN_TOKEN ) + # we use ret.get() below because *lots* of unit tests stub out + # get_user_by_access_token in a way where it only returns a couple of + # the fields. user_info = { "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "device_id": ret.get("device_id"), } defer.returnValue(user_info) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index a90990e006..74330a8ddf 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -35,7 +35,7 @@ class ClientIpStore(SQLBaseStore): super(ClientIpStore, self).__init__(hs) @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, ip, user_agent): + def insert_client_ip(self, user, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user.to_string(), access_token, ip) @@ -59,6 +59,7 @@ class ClientIpStore(SQLBaseStore): "access_token": access_token, "ip": ip, "user_agent": user_agent, + "device_id": device_id, }, values={ "last_seen": now, diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 960c23d631..e91723ca3d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -45,6 +45,7 @@ class AuthTestCase(unittest.TestCase): user_info = { "name": self.test_user, "token_id": "ditto", + "device_id": "device", } self.store.get_user_by_access_token = Mock(return_value=user_info) @@ -143,7 +144,10 @@ class AuthTestCase(unittest.TestCase): # TODO(danielwh): Remove this mock when we remove the # get_user_by_access_token fallback. self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org"} + return_value={ + "name": "@baldrick:matrix.org", + "device_id": "device", + } ) user_id = "@baldrick:matrix.org" @@ -158,6 +162,10 @@ class AuthTestCase(unittest.TestCase): user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) + # TODO: device_id should come from the macaroon, but currently comes + # from the db. + self.assertEqual(user_info["device_id"], "device") + @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): user_id = "@baldrick:matrix.org" -- cgit 1.4.1 From bc8f265f0a8443e918b17a94f4b2fa319e70a21f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 16:34:00 +0100 Subject: GET /devices endpoint implement a GET /devices endpoint which lists all of the user's devices. It also returns the last IP where we saw that device, so there is some dancing to fish that out of the user_ips table. --- synapse/handlers/device.py | 27 ++++++++ synapse/rest/__init__.py | 2 + synapse/rest/client/v2_alpha/_base.py | 13 ++-- synapse/rest/client/v2_alpha/devices.py | 51 ++++++++++++++ synapse/storage/client_ips.py | 72 ++++++++++++++++++++ synapse/storage/devices.py | 22 +++++- synapse/storage/schema/delta/33/user_ips_index.sql | 16 +++++ tests/handlers/test_device.py | 78 ++++++++++++++++++---- tests/storage/test_client_ips.py | 62 +++++++++++++++++ tests/storage/test_devices.py | 71 ++++++++++++++++++++ 10 files changed, 397 insertions(+), 17 deletions(-) create mode 100644 synapse/rest/client/v2_alpha/devices.py create mode 100644 synapse/storage/schema/delta/33/user_ips_index.sql create mode 100644 tests/storage/test_client_ips.py create mode 100644 tests/storage/test_devices.py (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 8d7d9874f8..6bbbf59e52 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -69,3 +69,30 @@ class DeviceHandler(BaseHandler): attempts += 1 raise StoreError(500, "Couldn't generate a device ID.") + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """ + Retrieve the given user's devices + + Args: + user_id (str): + Returns: + defer.Deferred: dict[str, dict[str, X]]: map from device_id to + info on the device + """ + + devices = yield self.store.get_devices_by_user(user_id) + + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id) for device_id in devices.keys()) + ) + + for device_id in devices.keys(): + ip = ips.get((user_id, device_id), {}) + devices[device_id].update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) + + defer.returnValue(devices) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 8b223e032b..14227f1cdb 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import ( account_data, report_event, openid, + devices, ) from synapse.http.server import JsonResource @@ -90,3 +91,4 @@ class ClientRestResource(JsonResource): account_data.register_servlets(hs, client_resource) report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) + devices.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b6faa2b0e6..20e765f48f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -25,7 +25,9 @@ import logging logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,)): +def client_v2_patterns(path_regex, releases=(0,), + v2_alpha=True, + unstable=True): """Creates a regex compiled client path with the correct client path prefix. @@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)): Returns: SRE_Pattern """ - patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] - unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") - patterns.append(re.compile("^" + unstable_prefix + path_regex)) + patterns = [] + if v2_alpha: + patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) + if unstable: + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) for release in releases: new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py new file mode 100644 index 0000000000..5cf8bd1afa --- /dev/null +++ b/synapse/rest/client/v2_alpha/devices.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet + +from ._base import client_v2_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DevicesRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + devices = yield self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + defer.returnValue((200, {"devices": devices})) + + +def register_servlets(hs, http_server): + DevicesRestServlet(hs).register(http_server) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index a90990e006..07161496ca 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from ._base import SQLBaseStore, Cache from twisted.internet import defer +logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller # times give more inserts into the database even for readonly API hits @@ -66,3 +69,72 @@ class ClientIpStore(SQLBaseStore): desc="insert_client_ip", lock=False, ) + + @defer.inlineCallbacks + def get_last_client_ip_by_device(self, devices): + """For each device_id listed, give the user_ip it was last seen on + + Args: + devices (iterable[(str, str)]): list of (user_id, device_id) pairs + + Returns: + defer.Deferred: resolves to a dict, where the keys + are (user_id, device_id) tuples. The values are also dicts, with + keys giving the column names + """ + + res = yield self.runInteraction( + "get_last_client_ip_by_device", + self._get_last_client_ip_by_device_txn, + retcols=( + "user_id", + "access_token", + "ip", + "user_agent", + "device_id", + "last_seen", + ), + devices=devices + ) + + ret = {(d["user_id"], d["device_id"]): d for d in res} + defer.returnValue(ret) + + @classmethod + def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): + def where_clause_for_device(d): + return + + where_clauses = [] + bindings = [] + for (user_id, device_id) in devices: + if device_id is None: + where_clauses.append("(user_id = ? AND device_id IS NULL)") + bindings.extend((user_id, )) + else: + where_clauses.append("(user_id = ? AND device_id = ?)") + bindings.extend((user_id, device_id)) + + inner_select = ( + "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " + "WHERE %(where)s " + "GROUP BY user_id, device_id" + ) % { + "where": " OR ".join(where_clauses), + } + + sql = ( + "SELECT %(retcols)s FROM user_ips " + "JOIN (%(inner_select)s) ips ON" + " user_ips.last_seen = ips.mls AND" + " user_ips.user_id = ips.user_id AND" + " (user_ips.device_id = ips.device_id OR" + " (user_ips.device_id IS NULL AND ips.device_id IS NULL)" + " )" + ) % { + "retcols": ",".join("user_ips." + c for c in retcols), + "inner_select": inner_select, + } + + txn.execute(sql, bindings) + return cls.cursor_to_dict(txn) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9065e96d28..1cc6e07f2b 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -65,7 +65,7 @@ class DeviceStore(SQLBaseStore): user_id (str): The ID of the user which owns the device device_id (str): The ID of the device to retrieve Returns: - defer.Deferred for a namedtuple containing the device information + defer.Deferred for a dict containing the device information Raises: StoreError: if the device is not found """ @@ -75,3 +75,23 @@ class DeviceStore(SQLBaseStore): retcols=("user_id", "device_id", "display_name"), desc="get_device", ) + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """Retrieve all of a user's registered devices. + + Args: + user_id (str): + Returns: + defer.Deferred: resolves to a dict from device_id to a dict + containing "device_id", "user_id" and "display_name" for each + device. + """ + devices = yield self._simple_select_list( + table="devices", + keyvalues={"user_id": user_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_devices_by_user" + ) + + defer.returnValue({d["device_id"]: d for d in devices}) diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/schema/delta/33/user_ips_index.sql new file mode 100644 index 0000000000..8a05677d42 --- /dev/null +++ b/synapse/storage/schema/delta/33/user_ips_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX user_ips_device_id ON user_ips(user_id, device_id, last_seen); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index cc6512ccc7..c2e12135d6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,25 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from synapse import types from twisted.internet import defer -from synapse.handlers.device import DeviceHandler -from tests import unittest -from tests.utils import setup_test_homeserver - - -class DeviceHandlers(object): - def __init__(self, hs): - self.device_handler = DeviceHandler(hs) +import synapse.handlers.device +import synapse.storage +from tests import unittest, utils class DeviceTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(DeviceTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + self.handler = None # type: device.DeviceHandler + self.clock = None # type: utils.MockClock + @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) - self.hs.handlers = handlers = DeviceHandlers(self.hs) - self.handler = handlers.device_handler + hs = yield utils.setup_test_homeserver(handlers=None) + self.handler = synapse.handlers.device.DeviceHandler(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): @@ -73,3 +75,55 @@ class DeviceTestCase(unittest.TestCase): dev = yield self.handler.store.get_device("theresa", device_id) self.assertEqual(dev["display_name"], "display") + + @defer.inlineCallbacks + def test_get_devices_by_user(self): + # check this works for both devices which have a recorded client_ip, + # and those which don't. + user1 = "@boris:aaa" + user2 = "@theresa:bbb" + yield self._record_user(user1, "xyz", "display 0") + yield self._record_user(user1, "fco", "display 1", "token1", "ip1") + yield self._record_user(user1, "abc", "display 2", "token2", "ip2") + yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + + yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + + res = yield self.handler.get_devices_by_user(user1) + self.assertEqual(3, len(res.keys())) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "xyz", + "display_name": "display 0", + "last_seen_ip": None, + "last_seen_ts": None, + }, res["xyz"]) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "fco", + "display_name": "display 1", + "last_seen_ip": "ip1", + "last_seen_ts": 1000000, + }, res["fco"]) + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, res["abc"]) + + @defer.inlineCallbacks + def _record_user(self, user_id, device_id, display_name, + access_token=None, ip=None): + device_id = yield self.handler.check_device_registered( + user_id=user_id, + device_id=device_id, + initial_device_display_name=display_name + ) + + if ip is not None: + yield self.store.insert_client_ip( + types.UserID.from_string(user_id), + access_token, ip, "user_agent", device_id) + self.clock.advance_time(1000) \ No newline at end of file diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py new file mode 100644 index 0000000000..1f0c0e7c37 --- /dev/null +++ b/tests/storage/test_client_ips.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.server +import synapse.storage +import synapse.types +import tests.unittest +import tests.utils + + +class ClientIpStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + self.clock = None # type: tests.utils.MockClock + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def test_insert_new_client_ip(self): + self.clock.now = 12345678 + user_id = "@user:id" + yield self.store.insert_client_ip( + synapse.types.UserID.from_string(user_id), + "access_token", "ip", "user_agent", "device_id", + ) + + # deliberately use an iterable here to make sure that the lookup + # method doesn't iterate it twice + device_list = iter(((user_id, "device_id"),)) + result = yield self.store.get_last_client_ip_by_device(device_list) + + r = result[(user_id, "device_id")] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": "device_id", + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 12345678000, + }, + r + ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py new file mode 100644 index 0000000000..d3e9d97a9a --- /dev/null +++ b/tests/storage/test_devices.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.server +import synapse.types +import tests.unittest +import tests.utils + + +class DeviceStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(DeviceStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_store_new_device(self): + yield self.store.store_device( + "user_id", "device_id", "display_name" + ) + + res = yield self.store.get_device("user_id", "device_id") + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device_id", + "display_name": "display_name", + }, res) + + @defer.inlineCallbacks + def test_get_devices_by_user(self): + yield self.store.store_device( + "user_id", "device1", "display_name 1" + ) + yield self.store.store_device( + "user_id", "device2", "display_name 2" + ) + yield self.store.store_device( + "user_id2", "device3", "display_name 3" + ) + + res = yield self.store.get_devices_by_user("user_id") + self.assertEqual(2, len(res.keys())) + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device1", + "display_name": "display_name 1", + }, res["device1"]) + self.assertDictContainsSubset({ + "user_id": "user_id", + "device_id": "device2", + "display_name": "display_name 2", + }, res["device2"]) -- cgit 1.4.1 From 40a1c96617fd5926b53f8993bb93af159af4d674 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 18:06:28 +0100 Subject: Fix PEP8 errors --- tests/handlers/test_device.py | 2 +- tests/storage/test_devices.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index c2e12135d6..b05aa9bb55 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -126,4 +126,4 @@ class DeviceTestCase(unittest.TestCase): yield self.store.insert_client_ip( types.UserID.from_string(user_id), access_token, ip, "user_agent", device_id) - self.clock.advance_time(1000) \ No newline at end of file + self.clock.advance_time(1000) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index d3e9d97a9a..a6ce993375 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -15,8 +15,6 @@ from twisted.internet import defer -import synapse.server -import synapse.types import tests.unittest import tests.utils -- cgit 1.4.1 From 406f7aa0f6ca7433e52433485824e80b79930498 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Jul 2016 17:58:44 +0100 Subject: Implement GET /device/{deviceId} --- synapse/handlers/device.py | 46 ++++++++++++++++++++++++++------- synapse/rest/client/v2_alpha/devices.py | 25 ++++++++++++++++++ tests/handlers/test_device.py | 37 +++++++++++++++++++------- 3 files changed, 89 insertions(+), 19 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6bbbf59e52..3c88be0679 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import StoreError + +from synapse.api import errors from synapse.util import stringutils from twisted.internet import defer from ._base import BaseHandler @@ -65,10 +66,10 @@ class DeviceHandler(BaseHandler): ignore_if_known=False, ) defer.returnValue(device_id) - except StoreError: + except errors.StoreError: attempts += 1 - raise StoreError(500, "Couldn't generate a device ID.") + raise errors.StoreError(500, "Couldn't generate a device ID.") @defer.inlineCallbacks def get_devices_by_user(self, user_id): @@ -88,11 +89,38 @@ class DeviceHandler(BaseHandler): devices=((user_id, device_id) for device_id in devices.keys()) ) - for device_id in devices.keys(): - ip = ips.get((user_id, device_id), {}) - devices[device_id].update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + for device in devices.values(): + _update_device_from_client_ips(device, ips) defer.returnValue(devices) + + @defer.inlineCallbacks + def get_device(self, user_id, device_id): + """ Retrieve the given device + + Args: + user_id (str): + device_id (str) + + Returns: + defer.Deferred: dict[str, X]: info on the device + Raises: + errors.NotFoundError: if the device was not found + """ + try: + device = yield self.store.get_device(user_id, device_id) + except errors.StoreError, e: + raise errors.NotFoundError + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id),) + ) + _update_device_from_client_ips(device, ips) + defer.returnValue(device) + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 5cf8bd1afa..8b9ab4f674 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -47,5 +47,30 @@ class DevicesRestServlet(RestServlet): defer.returnValue((200, {"devices": devices})) +class DeviceRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", + releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + device = yield self.device_handler.get_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, device)) + + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b05aa9bb55..73f09874d8 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,6 +19,8 @@ import synapse.handlers.device import synapse.storage from tests import unittest, utils +user1 = "@boris:aaa" +user2 = "@theresa:bbb" class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -78,16 +80,7 @@ class DeviceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_devices_by_user(self): - # check this works for both devices which have a recorded client_ip, - # and those which don't. - user1 = "@boris:aaa" - user2 = "@theresa:bbb" - yield self._record_user(user1, "xyz", "display 0") - yield self._record_user(user1, "fco", "display 1", "token1", "ip1") - yield self._record_user(user1, "abc", "display 2", "token2", "ip2") - yield self._record_user(user1, "abc", "display 2", "token3", "ip3") - - yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + yield self._record_users() res = yield self.handler.get_devices_by_user(user1) self.assertEqual(3, len(res.keys())) @@ -113,6 +106,30 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res["abc"]) + @defer.inlineCallbacks + def test_get_device(self): + yield self._record_users() + + res = yield self.handler.get_device(user1, "abc") + self.assertDictContainsSubset({ + "user_id": user1, + "device_id": "abc", + "display_name": "display 2", + "last_seen_ip": "ip3", + "last_seen_ts": 3000000, + }, res) + + @defer.inlineCallbacks + def _record_users(self): + # check this works for both devices which have a recorded client_ip, + # and those which don't. + yield self._record_user(user1, "xyz", "display 0") + yield self._record_user(user1, "fco", "display 1", "token1", "ip1") + yield self._record_user(user1, "abc", "display 2", "token2", "ip2") + yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + + yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + @defer.inlineCallbacks def _record_user(self, user_id, device_id, display_name, access_token=None, ip=None): -- cgit 1.4.1 From 1c3c202b969d6a7e5e4af2b2dca370f053b92c9f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 21 Jul 2016 13:15:15 +0100 Subject: Fix PEP8 errors --- synapse/handlers/device.py | 2 +- tests/handlers/test_device.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 3c88be0679..110f5fbb5c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -109,7 +109,7 @@ class DeviceHandler(BaseHandler): """ try: device = yield self.store.get_device(user_id, device_id) - except errors.StoreError, e: + except errors.StoreError: raise errors.NotFoundError ips = yield self.store.get_last_client_ip_by_device( devices=((user_id, device_id),) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 73f09874d8..87c3c75aea 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -22,6 +22,7 @@ from tests import unittest, utils user1 = "@boris:aaa" user2 = "@theresa:bbb" + class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) -- cgit 1.4.1 From 55abbe1850efff95efe9935873b666e5fc4bf0e9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 21 Jul 2016 15:55:13 +0100 Subject: make /devices return a list Turns out I specced this to return a list of devices rather than a dict of them --- synapse/handlers/device.py | 10 +++++----- tests/handlers/test_device.py | 11 +++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 110f5fbb5c..1f9e15c33c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -79,17 +79,17 @@ class DeviceHandler(BaseHandler): Args: user_id (str): Returns: - defer.Deferred: dict[str, dict[str, X]]: map from device_id to - info on the device + defer.Deferred: list[dict[str, X]]: info on each device """ - devices = yield self.store.get_devices_by_user(user_id) + device_map = yield self.store.get_devices_by_user(user_id) ips = yield self.store.get_last_client_ip_by_device( - devices=((user_id, device_id) for device_id in devices.keys()) + devices=((user_id, device_id) for device_id in device_map.keys()) ) - for device in devices.values(): + devices = device_map.values() + for device in devices: _update_device_from_client_ips(device, ips) defer.returnValue(devices) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 87c3c75aea..331aa13fed 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -84,28 +84,31 @@ class DeviceTestCase(unittest.TestCase): yield self._record_users() res = yield self.handler.get_devices_by_user(user1) - self.assertEqual(3, len(res.keys())) + self.assertEqual(3, len(res)) + device_map = { + d["device_id"]: d for d in res + } self.assertDictContainsSubset({ "user_id": user1, "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, "last_seen_ts": None, - }, res["xyz"]) + }, device_map["xyz"]) self.assertDictContainsSubset({ "user_id": user1, "device_id": "fco", "display_name": "display 1", "last_seen_ip": "ip1", "last_seen_ts": 1000000, - }, res["fco"]) + }, device_map["fco"]) self.assertDictContainsSubset({ "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }, res["abc"]) + }, device_map["abc"]) @defer.inlineCallbacks def test_get_device(self): -- cgit 1.4.1 From 465117d7ca40ba9906697aa023897798f7833830 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 12:10:42 +0100 Subject: Fix background_update tests A bit of a cleanup for background_updates, and make sure that the real background updates have run before we start the unit tests, so that they don't interfere with the tests. --- synapse/storage/background_updates.py | 27 ++++++++++++++++++++------- tests/storage/test_background_update.py | 22 ++++++++++++++++------ 2 files changed, 36 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 75951d0173..2771f7c3c1 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -88,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore): @defer.inlineCallbacks def start_doing_background_updates(self): - while True: - if self._background_update_timer is not None: - return + assert(self._background_update_timer is not None, + "background updates already running") + + logger.info("Starting background schema updates") + while True: sleep = defer.Deferred() self._background_update_timer = self._clock.call_later( self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None @@ -102,7 +104,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_timer = None try: - result = yield self.do_background_update( + result = yield self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) except: @@ -113,11 +115,12 @@ class BackgroundUpdateStore(SQLBaseStore): "No more background updates to do." " Unscheduling background update task." ) - return + defer.returnValue() @defer.inlineCallbacks - def do_background_update(self, desired_duration_ms): - """Does some amount of work on a background update + def do_next_background_update(self, desired_duration_ms): + """Does some amount of work on the next queued background update + Args: desired_duration_ms(float): How long we want to spend updating. @@ -136,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue.append(update['update_name']) if not self._background_update_queue: + # no work left to do defer.returnValue(None) + # pop from the front, and add back to the back update_name = self._background_update_queue.pop(0) self._background_update_queue.append(update_name) + res = yield self._do_background_update(update_name, desired_duration_ms) + defer.returnValue(res) + + @defer.inlineCallbacks + def _do_background_update(self, update_name, desired_duration_ms): + logger.info("Starting update batch on background update '%s'", + update_name) + update_handler = self._background_update_handlers[update_name] performance = self._background_update_performance.get(update_name) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 6e4d9b1373..4944cb0d2e 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver() + hs = yield setup_test_homeserver() # type: synapse.server.HomeServer self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase): "test_update", self.update_handler ) + # run the real background updates, to get them out the way + # (perhaps we should run them as part of the test HS setup, since we + # run all of the other schema setup stuff there?) + while True: + res = yield self.store.do_next_background_update(1000) + if res is None: + break + @defer.inlineCallbacks def test_do_background_update(self): desired_count = 1000 duration_ms = 42 + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): self.clock.advance_time_msec(count * duration_ms) @@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): yield self.store.start_background_update("test_update", {"my_key": 1}) self.update_handler.reset_mock() - result = yield self.store.do_background_update( + result = yield self.store.do_next_background_update( duration_ms * desired_count ) self.assertIsNotNone(result) @@ -50,24 +59,25 @@ class BackgroundUpdateTestCase(unittest.TestCase): {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE ) + # second step: complete the update @defer.inlineCallbacks def update(progress, count): yield self.store._end_background_update("test_update") defer.returnValue(count) self.update_handler.side_effect = update - self.update_handler.reset_mock() - result = yield self.store.do_background_update( - duration_ms * desired_count + result = yield self.store.do_next_background_update( + duration_ms * desired_count ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( {"my_key": 2}, desired_count ) + # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_background_update( + result = yield self.store.do_next_background_update( duration_ms * desired_count ) self.assertIsNone(result) -- cgit 1.4.1 From f16f0e169d30e6920b892ee772693199b16713fd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 12:12:47 +0100 Subject: Slightly saner logging for unittests 1. Give the handler used for logging in unit tests a formatter, so that the output is slightly more meaningful 2. Log some synapse.storage stuff, because it's useful. --- tests/unittest.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index 5b22abfe74..38715972dd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -17,13 +17,18 @@ from twisted.trial import unittest import logging - # logging doesn't have a "don't log anything at all EVARRRR setting, # but since the highest value is 50, 1000000 should do ;) NEVER = 1000000 -logging.getLogger().addHandler(logging.StreamHandler()) +handler = logging.StreamHandler() +handler.setFormatter(logging.Formatter( + "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]" +)) +logging.getLogger().addHandler(handler) logging.getLogger().setLevel(NEVER) +logging.getLogger("synapse.storage.SQL").setLevel(NEVER) +logging.getLogger("synapse.storage.txn").setLevel(NEVER) def around(target): @@ -70,8 +75,6 @@ class TestCase(unittest.TestCase): return ret logging.getLogger().setLevel(level) - # Don't set SQL logging - logging.getLogger("synapse.storage").setLevel(old_level) return orig() def assertObjectHasAttributes(self, attrs, obj): -- cgit 1.4.1 From 42f4feb2b709671bb2dbbabfe1aad7e951479652 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 12:25:06 +0100 Subject: PEP8 --- tests/storage/test_background_update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 4944cb0d2e..1286b4ce2d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -68,7 +68,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update self.update_handler.reset_mock() result = yield self.store.do_next_background_update( - duration_ms * desired_count + duration_ms * desired_count ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( -- cgit 1.4.1 From 436bffd15fb8382a0d2dddd3c6f7a077ba751da2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 22 Jul 2016 14:52:53 +0100 Subject: Implement deleting devices --- synapse/handlers/auth.py | 22 ++++++++++++++++-- synapse/handlers/device.py | 27 +++++++++++++++++++++- synapse/rest/client/v1/login.py | 13 ++++++++--- synapse/rest/client/v2_alpha/devices.py | 14 +++++++++++ synapse/rest/client/v2_alpha/register.py | 10 ++++---- synapse/storage/devices.py | 15 ++++++++++++ synapse/storage/registration.py | 26 +++++++++++++++++---- .../schema/delta/33/access_tokens_device_index.sql | 17 ++++++++++++++ .../schema/delta/33/refreshtoken_device_index.sql | 17 ++++++++++++++ tests/handlers/test_device.py | 22 ++++++++++++++++-- tests/rest/client/v2_alpha/test_register.py | 14 +++++++---- 11 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 synapse/storage/schema/delta/33/access_tokens_device_index.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device_index.sql (limited to 'tests') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d2072436..2e138f328f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -77,6 +77,7 @@ class AuthHandler(BaseHandler): self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -374,7 +375,8 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None): + def get_login_tuple_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ Gets login tuple for the user with the given user ID. @@ -383,9 +385,15 @@ class AuthHandler(BaseHandler): The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. + The device will be recorded in the table if it is not there already. + Args: user_id (str): canonical User ID - device_id (str): the device ID to associate with the access token + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: A tuple of: The access token for the user's session. @@ -397,6 +405,16 @@ class AuthHandler(BaseHandler): logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1f9e15c33c..a7a192e1c9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler): Args: user_id (str): - device_id (str) + device_id (str): Returns: defer.Deferred: dict[str, X]: info on the device @@ -117,6 +117,31 @@ class DeviceHandler(BaseHandler): _update_device_from_client_ips(device, ips) defer.returnValue(device) + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens(user_id, + device_id=device_id) + + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e8b791519c..92fcae674a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id + registered_user_id, device_id, + login_submission.get("initial_device_display_name") ) ) result = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 8b9ab4f674..30ef8b3da9 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -70,6 +70,20 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, device)) + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c8c9395fc6..9f599ea8bb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token = yield self.auth_handler.issue_access_token( - user_id, device_id=device_id + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - refresh_token = yield self.auth_handler.issue_refresh_token( - user_id, device_id=device_id - ) defer.returnValue({ "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1cc6e07f2b..4689980f80 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -76,6 +76,21 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred + """ + return self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9a92b35361..935e82bf7a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,18 +18,31 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes - -from ._base import SQLBaseStore +from synapse.storage import background_updates from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def user_delete_access_tokens(self, user_id, except_token_ids=[], + device_id=None): def f(txn): sql = "SELECT token FROM access_tokens WHERE user_id = ?" clauses = [user_id] + if device_id is not None: + sql += " AND device_id = ?" + clauses.append(device_id) + if except_token_ids: sql += " AND id NOT IN (%s)" % ( ",".join(["?" for _ in except_token_ids]), diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql new file mode 100644 index 0000000000..bb225dafbf --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('refresh_tokens_device_index', '{}'); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 331aa13fed..214e722eb3 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse import types + from twisted.internet import defer +import synapse.api.errors import synapse.handlers.device + import synapse.storage +from synapse import types from tests import unittest, utils user1 = "@boris:aaa" @@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore - self.handler = None # type: device.DeviceHandler + self.handler = None # type: synapse.handlers.device.DeviceHandler self.clock = None # type: utils.MockClock @defer.inlineCallbacks @@ -123,6 +126,21 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res) + @defer.inlineCallbacks + def test_delete_device(self): + yield self._record_users() + + # delete the device + yield self.handler.delete_device(user1, "abc") + + # check the device was deleted + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.get_device(user1, "abc") + + # we'd like to check the access token was invalidated, but that's a + # bit of a PITA. + + @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3bd7065e32..8ac56a1fb2 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=user_id ) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname } self.assertDictContainsSubset(det_data, result) @@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) self.device_handler.check_device_registered = \ Mock(return_value=device_id) @@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname, "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) - self.auth_handler.issue_access_token.assert_called_once_with( - user_id, device_id=device_id) + self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, initial_device_display_name=None) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.4.1 From 012b4c19132d57fdbc1b6b0e304eb60eaf19200f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 25 Jul 2016 17:51:24 +0100 Subject: Implement updating devices You can update the displayname of devices now. --- synapse/handlers/device.py | 24 ++++++++++++++++++++++ synapse/rest/client/v2_alpha/devices.py | 24 +++++++++++++++------- synapse/storage/devices.py | 27 ++++++++++++++++++++++++- tests/handlers/test_device.py | 16 +++++++++++++++ tests/storage/test_devices.py | 36 +++++++++++++++++++++++++++++++++ 5 files changed, 119 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a7a192e1c9..9e65d85e6d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler): yield self.store.user_delete_access_tokens(user_id, device_id=device_id) + @defer.inlineCallbacks + def update_device(self, user_id, device_id, content): + """ Update the given device + + Args: + user_id (str): + device_id (str): + content (dict): body of update request + + Returns: + defer.Deferred: + """ + + try: + yield self.store.update_device( + user_id, + device_id, + new_display_name=content.get("display_name") + ) + except errors.StoreError, e: + if e.code == 404: + raise errors.NotFoundError() + else: + raise def _update_device_from_client_ips(device, client_ips): diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 30ef8b3da9..8fbd3d3dfc 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -13,19 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +import logging -from synapse.http.servlet import RestServlet +from twisted.internet import defer +from synapse.http import servlet from ._base import client_v2_patterns -import logging - - logger = logging.getLogger(__name__) -class DevicesRestServlet(RestServlet): +class DevicesRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) def __init__(self, hs): @@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet): defer.returnValue((200, {"devices": devices})) -class DeviceRestServlet(RestServlet): +class DeviceRestServlet(servlet.RestServlet): PATTERNS = client_v2_patterns("/devices/(?P[^/]*)$", releases=[], v2_alpha=False) @@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, {})) + @defer.inlineCallbacks + def on_PUT(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + + body = servlet.parse_json_object_from_request(request) + yield self.device_handler.update_device( + requester.user.to_string(), + device_id, + body + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 4689980f80..afd6530cab 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore): Args: user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to retrieve + device_id (str): The ID of the device to delete Returns: defer.Deferred """ @@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore): desc="delete_device", ) + def update_device(self, user_id, device_id, new_display_name=None): + """Update a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to update + new_display_name (str|None): new displayname for device; None + to leave unchanged + Raises: + StoreError: if the device is not found + Returns: + defer.Deferred + """ + updates = {} + if new_display_name is not None: + updates["display_name"] = new_display_name + if not updates: + return defer.succeed(None) + return self._simple_update_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues=updates, + desc="update_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 214e722eb3..85a970a6c9 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase): # we'd like to check the access token was invalidated, but that's a # bit of a PITA. + @defer.inlineCallbacks + def test_update_device(self): + yield self._record_users() + + update = {"display_name": "new display"} + yield self.handler.update_device(user1, "abc", update) + + res = yield self.handler.get_device(user1, "abc") + self.assertEqual(res["display_name"], "new display") + + @defer.inlineCallbacks + def test_update_unknown_device(self): + update = {"display_name": "new_display"} + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.update_device("user_id", "unknown_device_id", + update) @defer.inlineCallbacks def _record_users(self): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index a6ce993375..f8725acea0 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -15,6 +15,7 @@ from twisted.internet import defer +import synapse.api.errors import tests.unittest import tests.utils @@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase): "device_id": "device2", "display_name": "display_name 2", }, res["device2"]) + + @defer.inlineCallbacks + def test_update_device(self): + yield self.store.store_device( + "user_id", "device_id", "display_name 1" + ) + + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 1", res["display_name"]) + + # do a no-op first + yield self.store.update_device( + "user_id", "device_id", + ) + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 1", res["display_name"]) + + # do the update + yield self.store.update_device( + "user_id", "device_id", + new_display_name="display_name 2", + ) + + # check it worked + res = yield self.store.get_device("user_id", "device_id") + self.assertEqual("display_name 2", res["display_name"]) + + @defer.inlineCallbacks + def test_update_unknown_device(self): + with self.assertRaises(synapse.api.errors.StoreError) as cm: + yield self.store.update_device( + "user_id", "unknown_device_id", + new_display_name="display_name 2", + ) + self.assertEqual(404, cm.exception.code) -- cgit 1.4.1 From 8e0249416643f20f0c4cd8f2e19cf45ea63289d3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 26 Jul 2016 11:09:47 +0100 Subject: Delete refresh tokens when deleting devices --- synapse/handlers/device.py | 6 ++-- synapse/storage/registration.py | 58 +++++++++++++++++++++++++++++--------- tests/storage/test_registration.py | 34 ++++++++++++++++++++++ 3 files changed, 83 insertions(+), 15 deletions(-) (limited to 'tests') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9e65d85e6d..eaead50800 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -138,8 +138,10 @@ class DeviceHandler(BaseHandler): else: raise - yield self.store.user_delete_access_tokens(user_id, - device_id=device_id) + yield self.store.user_delete_access_tokens( + user_id, device_id=device_id, + delete_refresh_tokens=True, + ) @defer.inlineCallbacks def update_device(self, user_id, device_id, content): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 935e82bf7a..d9555e073a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -252,20 +252,36 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_ids=[], - device_id=None): - def f(txn): - sql = "SELECT token FROM access_tokens WHERE user_id = ?" + device_id=None, + delete_refresh_tokens=False): + """ + Invalidate access/refresh tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_ids (list[str]): list of access_tokens which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + delete_refresh_tokens (bool): True to delete refresh tokens as + well as access tokens. + Returns: + defer.Deferred: + """ + def f(txn, table, except_tokens, call_after_delete): + sql = "SELECT token FROM %s WHERE user_id = ?" % table clauses = [user_id] if device_id is not None: sql += " AND device_id = ?" clauses.append(device_id) - if except_token_ids: + if except_tokens: sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_token_ids]), + ",".join(["?" for _ in except_tokens]), ) - clauses += except_token_ids + clauses += except_tokens txn.execute(sql, clauses) @@ -274,16 +290,33 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): n = 100 chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] for chunk in chunks: - for row in chunk: - txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + if call_after_delete: + for row in chunk: + txn.call_after(call_after_delete, (row[0],)) txn.execute( - "DELETE FROM access_tokens WHERE token in (%s)" % ( + "DELETE FROM %s WHERE token in (%s)" % ( + table, ",".join(["?" for _ in chunk]), ), [r[0] for r in chunk] ) - yield self.runInteraction("user_delete_access_tokens", f) + # delete refresh tokens first, to stop new access tokens being + # allocated while our backs are turned + if delete_refresh_tokens: + yield self.runInteraction( + "user_delete_access_tokens", f, + table="refresh_tokens", + except_tokens=[], + call_after_delete=None, + ) + + yield self.runInteraction( + "user_delete_access_tokens", f, + table="access_tokens", + except_tokens=except_token_ids, + call_after_delete=self.get_user_by_access_token.invalidate, + ) def delete_access_token(self, access_token): def f(txn): @@ -306,9 +339,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and the ID of their access token. - Raises: - StoreError if no user was found. + defer.Deferred: None, if the token did not match, ootherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( "get_user_by_access_token", diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b03ca303a2..f7d74dea8e 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase): with self.assertRaises(StoreError): yield self.store.exchange_refresh_token(last_token, generator.generate) + @defer.inlineCallbacks + def test_user_delete_access_tokens(self): + # add some tokens + generator = TokenGenerator() + refresh_token = generator.generate(self.user_id) + yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) + yield self.store.add_refresh_token_to_user(self.user_id, refresh_token, + self.device_id) + + # now delete some + yield self.store.user_delete_access_tokens( + self.user_id, device_id=self.device_id, delete_refresh_tokens=True) + + # check they were deleted + user = yield self.store.get_user_by_access_token(self.tokens[1]) + self.assertIsNone(user, "access token was not deleted by device_id") + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(refresh_token, + generator.generate) + + # check the one not associated with the device was not deleted + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertEqual(self.user_id, user["name"]) + + # now delete the rest + yield self.store.user_delete_access_tokens( + self.user_id, delete_refresh_tokens=True) + + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertIsNone(user, + "access token was not deleted without device_id") + class TokenGenerator: def __init__(self): -- cgit 1.4.1 From eb359eced44407b1ee9648f10fdf3df63c8d40ad Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 26 Jul 2016 16:46:53 +0100 Subject: Add `create_requester` function Wrap the `Requester` constructor with a function which provides sensible defaults, and use it throughout --- synapse/api/auth.py | 24 +++++++++++------------- synapse/handlers/_base.py | 13 +++++++------ synapse/handlers/profile.py | 12 +++++++----- synapse/handlers/register.py | 16 +++++++++------- synapse/handlers/room_member.py | 20 +++++++++----------- synapse/rest/client/v2_alpha/keys.py | 10 ++++------ synapse/types.py | 33 ++++++++++++++++++++++++++++++++- tests/handlers/test_profile.py | 10 ++++++---- tests/replication/test_resource.py | 20 +++++++++++--------- tests/rest/client/v1/test_profile.py | 13 +++++-------- tests/utils.py | 5 ----- 11 files changed, 101 insertions(+), 75 deletions(-) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index eca8513905..eecf3b0b2a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import pymacaroons from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json, SignatureVerifyException - from twisted.internet import defer +from unpaddedbase64 import decode_base64 +import synapse.types from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import Requester, UserID, get_domain_from_id -from synapse.util.logutils import log_function +from synapse.types import UserID, get_domain_from_id from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logutils import log_function from synapse.util.metrics import Measure -from unpaddedbase64 import decode_base64 - -import logging -import pymacaroons logger = logging.getLogger(__name__) @@ -566,8 +566,7 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - defer.Deferred: resolves to a namedtuple including "user" (UserID) - "access_token_id" (int), "is_guest" (bool) + defer.Deferred: resolves to a ``synapse.types.Requester`` object Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -576,9 +575,7 @@ class Auth(object): user_id = yield self._get_appservice_user_id(request.args) if user_id: request.authenticated_entity = user_id - defer.returnValue( - Requester(UserID.from_string(user_id), "", False) - ) + defer.returnValue(synapse.types.create_requester(user_id)) access_token = request.args["access_token"][0] user_info = yield self.get_user_by_access_token(access_token, rights) @@ -612,7 +609,8 @@ class Auth(object): request.authenticated_entity = user.to_string() - defer.returnValue(Requester(user, token_id, is_guest)) + defer.returnValue(synapse.types.create_requester( + user, token_id, is_guest, device_id)) except KeyError: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6264aa0d9a..11081a0cd5 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import LimitExceededError +import synapse.types from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, Requester - - -import logging +from synapse.api.errors import LimitExceededError +from synapse.types import UserID logger = logging.getLogger(__name__) @@ -124,7 +124,8 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = Requester(target_user, "", True) + requester = synapse.types.create_requester( + target_user, is_guest=True) handler = self.hs.get_handlers().room_member_handler yield handler.update_membership( requester, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 711a6a567f..d9ac09078d 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer +import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID, Requester - +from synapse.types import UserID from ._base import BaseHandler -import logging - logger = logging.getLogger(__name__) @@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler): try: # Assume the user isn't a guest because we don't let guests set # profile or avatar data. - requester = Requester(user, "", False) + # XXX why are we recreating `requester` here for each room? + # what was wrong with the `requester` we were passed? + requester = synapse.types.create_requester(user) yield handler.update_membership( requester, user, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 94b19d0cb0..b9b5880d64 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,18 +14,19 @@ # limitations under the License. """Contains functions for registering clients.""" +import logging +import urllib + from twisted.internet import defer -from synapse.types import UserID, Requester +import synapse.types from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) -from ._base import BaseHandler -from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient - -import logging -import urllib +from synapse.types import UserID +from synapse.util.async import run_on_reactor +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -410,8 +411,9 @@ class RegistrationHandler(BaseHandler): if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) profile_handler = self.hs.get_handlers().profile_handler + requester = synapse.types.create_requester(user) yield profile_handler.set_displayname( - user, Requester(user, token, False), displayname + user, requester, displayname ) defer.returnValue((user_id, token)) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7e616f44fd..8cec8fc4ed 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -14,24 +14,22 @@ # limitations under the License. -from twisted.internet import defer +import logging -from ._base import BaseHandler +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from twisted.internet import defer +from unpaddedbase64 import decode_base64 -from synapse.types import UserID, RoomID, Requester +import synapse.types from synapse.api.constants import ( EventTypes, Membership, ) from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.types import UserID, RoomID from synapse.util.async import Linearizer from synapse.util.distributor import user_left_room, user_joined_room - -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes - -from unpaddedbase64 import decode_base64 - -import logging +from ._base import BaseHandler logger = logging.getLogger(__name__) @@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler): ) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: - requester = Requester(target_user, None, False) + requester = synapse.types.create_requester(target_user) message_handler = self.hs.get_handlers().message_handler prev_event = message_handler.deduplicate_state_event(event, context) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 89ab39491c..56364af337 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -13,18 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import simplejson as json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID - -from canonicaljson import encode_canonical_json - from ._base import client_v2_patterns -import logging -import simplejson as json - logger = logging.getLogger(__name__) diff --git a/synapse/types.py b/synapse/types.py index f639651a73..5349b0c450 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError from collections import namedtuple -Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) +Requester = namedtuple("Requester", + ["user", "access_token_id", "is_guest", "device_id"]) +""" +Represents the user making a request + +Attributes: + user (UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time +""" + + +def create_requester(user_id, access_token_id=None, is_guest=False, + device_id=None): + """ + Create a new ``Requester`` object + + Args: + user_id (str|UserID): id of the user making the request + access_token_id (int|None): *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest (bool): True if the user making this request is a guest user + device_id (str|None): device_id which was set at authentication time + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + return Requester(user_id, access_token_id, is_guest, device_id) def get_domain_from_id(string): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 4f2c14e4ff..f1f664275f 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,11 +19,12 @@ from twisted.internet import defer from mock import Mock, NonCallableMock +import synapse.types from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID -from tests.utils import setup_test_homeserver, requester_for_user +from tests.utils import setup_test_homeserver class ProfileHandlers(object): @@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase): def test_set_my_name(self): yield self.handler.set_displayname( self.frank, - requester_for_user(self.frank), + synapse.types.create_requester(self.frank), "Frank Jr." ) @@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase): def test_set_my_name_noauth(self): d = self.handler.set_displayname( self.frank, - requester_for_user(self.bob), + synapse.types.create_requester(self.bob), "Frank Jr." ) @@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): yield self.handler.set_avatar_url( - self.frank, requester_for_user(self.frank), "http://my.server/pic.gif" + self.frank, synapse.types.create_requester(self.frank), + "http://my.server/pic.gif" ) self.assertEquals( diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 842e3d29d7..e70ac6f14d 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.resource import ReplicationResource -from synapse.types import Requester, UserID +import contextlib +import json +from mock import Mock, NonCallableMock from twisted.internet import defer + +import synapse.types +from synapse.replication.resource import ReplicationResource +from synapse.types import UserID from tests import unittest -from tests.utils import setup_test_homeserver, requester_for_user -from mock import Mock, NonCallableMock -import json -import contextlib +from tests.utils import setup_test_homeserver class ReplicationResourceCase(unittest.TestCase): @@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase): def test_events_and_state(self): get = self.get(events="-1", state="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( - Requester(self.user, "", False), {} + synapse.types.create_requester(self.user), {} ) code, body = yield get self.assertEquals(code, 200) @@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase): def send_text_message(self, room_id, message): handler = self.hs.get_handlers().message_handler event = yield handler.create_and_send_nonmember_event( - requester_for_user(self.user), + synapse.types.create_requester(self.user), { "type": "m.room.message", "content": {"body": "message", "msgtype": "m.text"}, @@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase): @defer.inlineCallbacks def create_room(self): result = yield self.hs.get_handlers().room_creation_handler.create_room( - Requester(self.user, "", False), {} + synapse.types.create_requester(self.user), {} ) defer.returnValue(result["room_id"]) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index af02fce8fb..1e95e97538 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,17 +14,14 @@ # limitations under the License. """Tests REST events for /profile paths.""" -from tests import unittest -from twisted.internet import defer - from mock import Mock +from twisted.internet import defer -from ....utils import MockHttpResource, setup_test_homeserver - +import synapse.types from synapse.api.errors import SynapseError, AuthError -from synapse.types import Requester, UserID - from synapse.rest.client.v1 import profile +from tests import unittest +from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" PATH_PREFIX = "/_matrix/client/api/v1" @@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return Requester(UserID.from_string(myid), "", False) + return synapse.types.create_requester(myid) hs.get_v1auth().get_user_by_req = _get_user_by_req diff --git a/tests/utils.py b/tests/utils.py index ed547bc39b..915b934e94 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer from synapse.federation.transport import server -from synapse.types import Requester from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.logcontext import LoggingContext @@ -512,7 +511,3 @@ class DeferredMockCallable(object): "call(%s)" % _format_call(c[0], c[1]) for c in calls ]) ) - - -def requester_for_user(user): - return Requester(user, None, False) -- cgit 1.4.1 From 0a7d3cd00f8b7e3ad0ba458c3ab9b40a2496545b Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 28 Jul 2016 20:24:24 +0100 Subject: Create separate methods for getting messages to push for the email and http pushers rather than trying to make a single method that will work with their conflicting requirements. The http pusher needs to get the messages in ascending stream order, and doesn't want to miss a message. The email pusher needs to get the messages in descending timestamp order, and doesn't mind if it misses messages. --- synapse/push/emailpusher.py | 5 +- synapse/push/httppusher.py | 3 +- synapse/replication/slave/storage/events.py | 7 +- synapse/storage/event_push_actions.py | 199 +++++++++++++++++++++------- tests/storage/test_event_push_actions.py | 41 ++++++ 5 files changed, 204 insertions(+), 51 deletions(-) create mode 100644 tests/storage/test_event_push_actions.py (limited to 'tests') diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 12a3ec7fd8..e224b68291 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -140,9 +140,8 @@ class EmailPusher(object): being run. """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering - unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( - self.user_id, start, self.max_stream_ordering - ) + fn = self.store.get_unread_push_actions_for_user_in_range_for_email + unprocessed = yield fn(self.user_id, start, self.max_stream_ordering) soonest_due_at = None diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 2acc6cc214..9a7db61220 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -141,7 +141,8 @@ class HttpPusher(object): run once per pusher. """ - unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( + fn = self.store.get_unread_push_actions_for_user_in_range_for_http + unprocessed = yield fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 369d839464..6a644f1386 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore): StreamStore.__dict__["get_recent_event_ids_for_room"] ) - get_unread_push_actions_for_user_in_range = ( - DataStore.get_unread_push_actions_for_user_in_range.__func__ + get_unread_push_actions_for_user_in_range_for_http = ( + DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ + ) + get_unread_push_actions_for_user_in_range_for_email = ( + DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ ) get_push_action_users_in_range = ( DataStore.get_push_action_users_in_range.__func__ diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 958dbcc22b..5ab362bef2 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -117,40 +117,149 @@ class EventPushActionsStore(SQLBaseStore): defer.returnValue(ret) @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range(self, user_id, - min_stream_ordering, - max_stream_ordering, - limit=20): + def get_unread_push_actions_for_user_in_range_for_http( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. + within the given stream ordering range. Called by the httppusher. Args: - user_id (str) - min_stream_ordering - max_stream_ordering - limit (int) + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. + Returns: + A promise which resolves to a list of dicts with the keys "event_id", + "room_id", "stream_ordering", "actions". + The list will be ordered by ascending stream_ordering. + The list will have between 0~limit entries. + """ + # find rooms that have a read receipt in them and return the next + # push actions + def get_after_receipt(txn): + # find rooms that have a read receipt in them and return the next + # push actions + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" + " FROM (" + " SELECT room_id," + " MAX(topological_ordering) as topological_ordering," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + ") AS rl," + " event_push_actions AS ep" + " WHERE" + " ep.room_id = rl.room_id" + " AND (" + " ep.topological_ordering > rl.topological_ordering" + " OR (" + " ep.topological_ordering = rl.topological_ordering" + " AND ep.stream_ordering > rl.stream_ordering" + " )" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + after_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + ) + + # There are rooms with push actions in them but you don't have a read receipt in + # them e.g. rooms you've been invited to, so get push actions for rooms which do + # not have read receipts in them too. + def get_no_receipt(txn): + sql = ( + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " e.received_ts" + " FROM event_push_actions AS ep" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering ASC LIMIT ?" + ) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] + txn.execute(sql, args) + return txn.fetchall() + no_read_receipt = yield self.runInteraction( + "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + ) + + notifs = [ + { + "event_id": row[0], + "room_id": row[1], + "stream_ordering": row[2], + "actions": json.loads(row[3]), + } for row in after_read_receipt + no_read_receipt + ] + + # Now sort it so it's ordered correctly, since currently it will + # contain results from the first query, correctly ordered, followed + # by results from the second query, but we want them all ordered + # by stream_ordering, oldest first. + notifs.sort(key=lambda r: r['stream_ordering']) + + # Take only up to the limit. We have to stop at the limit because + # one of the subqueries may have hit the limit. + defer.returnValue(notifs[:limit]) + + @defer.inlineCallbacks + def get_unread_push_actions_for_user_in_range_for_email( + self, user_id, min_stream_ordering, max_stream_ordering, limit=20 + ): + """Get a list of the most recent unread push actions for a given user, + within the given stream ordering range. Called by the emailpusher + + Args: + user_id (str): The user to fetch push actions for. + min_stream_ordering(int): The exclusive lower bound on the + stream ordering of event push actions to fetch. + max_stream_ordering(int): The inclusive upper bound on the + stream ordering of event push actions to fetch. + limit (int): The maximum number of rows to return. Returns: A promise which resolves to a list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". + The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ # find rooms that have a read receipt in them and return the most recent # push actions def get_after_receipt(txn): - # XXX: Do we really need to GROUP BY user_id on the inner SELECT? - # XXX: NATURAL JOIN obfuscates which columns are being joined on the - # inner SELECT (the room_id and event_id), can we - # INNER JOIN ... USING instead? sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, " - "e.received_ts " - "FROM (" - " SELECT room_id, user_id, " - " max(topological_ordering) as topological_ordering, " - " max(stream_ordering) as stream_ordering " - " FROM events" - " NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'" - " GROUP BY room_id, user_id" + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " e.received_ts" + " FROM (" + " SELECT room_id," + " MAX(topological_ordering) as topological_ordering," + " MAX(stream_ordering) as stream_ordering" + " FROM events" + " INNER JOIN receipts_linearized USING (room_id, event_id)" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" ") AS rl," " event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" @@ -165,47 +274,47 @@ class EventPushActionsStore(SQLBaseStore): " )" " AND ep.stream_ordering > ?" " AND ep.user_id = ?" - " AND ep.user_id = rl.user_id" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" ) - args = [min_stream_ordering, user_id] - if max_stream_ordering is not None: - sql += " AND ep.stream_ordering <= ?" - args.append(max_stream_ordering) - sql += " ORDER BY ep.stream_ordering DESC LIMIT ?" - args.append(limit) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] txn.execute(sql, args) return txn.fetchall() after_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range", get_after_receipt + "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) # There are rooms with push actions in them but you don't have a read receipt in # them e.g. rooms you've been invited to, so get push actions for rooms which do # not have read receipts in them too. def get_no_receipt(txn): - # XXX: Does the inner SELECT really need to select from the events table? - # We're just extracting the room_id, so isn't receipts_linearized enough? sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," " e.received_ts" " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.room_id not in (" - " SELECT room_id FROM events NATURAL JOIN receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AND ep.user_id = ? AND ep.stream_ordering > ?" + " INNER JOIN events AS e USING (room_id, event_id)" + " WHERE" + " ep.room_id NOT IN (" + " SELECT room_id FROM receipts_linearized" + " WHERE receipt_type = 'm.read' AND user_id = ?" + " GROUP BY room_id" + " )" + " AND ep.user_id = ?" + " AND ep.stream_ordering > ?" + " AND ep.stream_ordering <= ?" + " ORDER BY ep.stream_ordering DESC LIMIT ?" ) - args = [user_id, user_id, min_stream_ordering] - if max_stream_ordering is not None: - sql += " AND ep.stream_ordering <= ?" - args.append(max_stream_ordering) - sql += " ORDER BY ep.stream_ordering DESC LIMIT ?" - args.append(limit) + args = [ + user_id, user_id, + min_stream_ordering, max_stream_ordering, limit, + ] txn.execute(sql, args) return txn.fetchall() no_read_receipt = yield self.runInteraction( - "get_unread_push_actions_for_user_in_range", get_no_receipt + "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) # Make a list of dicts from the two sets of results. diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py new file mode 100644 index 0000000000..e9044afa2e --- /dev/null +++ b/tests/storage/test_event_push_actions.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import tests.unittest +import tests.utils + +USER_ID = "@user:example.com" + + +class EventPushActionsStoreTestCase(tests.unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_get_unread_push_actions_for_user_in_range_for_http(self): + yield self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) + + @defer.inlineCallbacks + def test_get_unread_push_actions_for_user_in_range_for_email(self): + yield self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) -- cgit 1.4.1 From 91fa69e0299167dad4df9831b8d175a99564b266 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 11:48:32 +0100 Subject: keys/query: return all users which were asked for In the situation where all of a user's devices get deleted, we want to indicate this to a client, so we want to return an empty dictionary, rather than nothing at all. --- synapse/handlers/e2e_keys.py | 9 +++++--- tests/handlers/test_e2e_keys.py | 46 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 tests/handlers/test_e2e_keys.py (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 1312cdf5ab..950fc927b1 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,6 +99,7 @@ class E2eKeysHandler(object): """ local_query = [] + result_dict = {} for user_id, device_ids in query.items(): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", @@ -111,15 +112,17 @@ class E2eKeysHandler(object): for device_id in device_ids: local_query.append((user_id, device_id)) + # make sure that each queried user appears in the result dict + result_dict[user_id] = {} + results = yield self.store.get_e2e_device_keys(local_query) # un-jsonify the results - json_result = collections.defaultdict(dict) for user_id, device_keys in results.items(): for device_id, json_bytes in device_keys.items(): - json_result[user_id][device_id] = json.loads(json_bytes) + result_dict[user_id][device_id] = json.loads(json_bytes) - defer.returnValue(json_result) + defer.returnValue(result_dict) @defer.inlineCallbacks def on_federation_query_client_keys(self, query_body): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py new file mode 100644 index 0000000000..878a54dc34 --- /dev/null +++ b/tests/handlers/test_e2e_keys.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +from twisted.internet import defer + +import synapse.api.errors +import synapse.handlers.e2e_keys + +import synapse.storage +from tests import unittest, utils + + +class E2eKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield utils.setup_test_homeserver( + handlers=None, + replication_layer=mock.Mock(), + ) + self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) + + @defer.inlineCallbacks + def test_query_local_devices_no_devices(self): + """If the user has no devices, we expect an empty list. + """ + local_user = "@boris:" + self.hs.hostname + res = yield self.handler.query_local_devices({local_user: None}) + self.assertDictEqual(res, {local_user: {}}) -- cgit 1.4.1 From 68264d7404214d30d32310991c89ea4a234c319a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 07:46:57 +0100 Subject: Include device name in /keys/query response Add an 'unsigned' section which includes the device display name. --- synapse/handlers/e2e_keys.py | 11 +++-- synapse/storage/end_to_end_keys.py | 60 ++++++++++++++++------- tests/storage/test_end_to_end_keys.py | 92 +++++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 20 deletions(-) create mode 100644 tests/storage/test_end_to_end_keys.py (limited to 'tests') diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 950fc927b1..bb69089b91 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -117,10 +117,15 @@ class E2eKeysHandler(object): results = yield self.store.get_e2e_device_keys(local_query) - # un-jsonify the results + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - result_dict[user_id][device_id] = json.loads(json_bytes) + for device_id, device_info in device_keys.items(): + r = json.loads(device_info["key_json"]) + r["unsigned"] = { + "device_display_name": device_info["device_display_name"], + } + result_dict[user_id][device_id] = r defer.returnValue(result_dict) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 62b7790e91..5c8ed3e492 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import collections import twisted.internet.defer @@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore): query_list(list): List of pairs of user_ids and device_ids. Returns: Dict mapping from user-id to dict mapping from device_id to - key json byte strings. + dict containing "key_json", "device_display_name". """ - def _get_e2e_device_keys(txn): - result = {} - for user_id, device_id in query_list: - user_result = result.setdefault(user_id, {}) - keyvalues = {"user_id": user_id} - if device_id: - keyvalues["device_id"] = device_id - rows = self._simple_select_list_txn( - txn, table="e2e_device_keys_json", - keyvalues=keyvalues, - retcols=["device_id", "key_json"] - ) - for row in rows: - user_result[row["device_id"]] = row["key_json"] - return result - return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) + if not query_list: + return {} + + return self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + ) + + def _get_e2e_device_keys_txn(self, txn, query_list): + query_clauses = [] + query_params = [] + + for (user_id, device_id) in query_list: + query_clause = "k.user_id = ?" + query_params.append(user_id) + + if device_id: + query_clause += " AND k.device_id = ?" + query_params.append(device_id) + + query_clauses.append(query_clause) + + sql = ( + "SELECT k.user_id, k.device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM e2e_device_keys_json k" + " LEFT JOIN devices d ON d.user_id = k.user_id" + " AND d.device_id = k.device_id" + " WHERE %s" + ) % ( + " OR ".join("("+q+")" for q in query_clauses) + ) + + txn.execute(sql, query_params) + rows = self.cursor_to_dict(txn) + + result = collections.defaultdict(dict) + for row in rows: + result[row["user_id"]][row["device_id"]] = row + + return result def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def _add_e2e_one_time_keys(txn): diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py new file mode 100644 index 0000000000..0ebc6dafe8 --- /dev/null +++ b/tests/storage/test_end_to_end_keys.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +import synapse.api.errors +import tests.unittest +import tests.utils + + +class EndToEndKeyStoreTestCase(tests.unittest.TestCase): + def __init__(self, *args, **kwargs): + super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs) + self.store = None # type: synapse.storage.DataStore + + @defer.inlineCallbacks + def setUp(self): + hs = yield tests.utils.setup_test_homeserver() + + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def test_key_without_device_name(self): + now = 1470174257070 + json = '{ "key": "value" }' + + yield self.store.set_e2e_device_keys( + "user", "device", now, json) + + res = yield self.store.get_e2e_device_keys((("user", "device"),)) + self.assertIn("user", res) + self.assertIn("device", res["user"]) + dev = res["user"]["device"] + self.assertDictContainsSubset({ + "key_json": json, + "device_display_name": None, + }, dev) + + @defer.inlineCallbacks + def test_get_key_with_device_name(self): + now = 1470174257070 + json = '{ "key": "value" }' + + yield self.store.set_e2e_device_keys( + "user", "device", now, json) + yield self.store.store_device( + "user", "device", "display_name" + ) + + res = yield self.store.get_e2e_device_keys((("user", "device"),)) + self.assertIn("user", res) + self.assertIn("device", res["user"]) + dev = res["user"]["device"] + self.assertDictContainsSubset({ + "key_json": json, + "device_display_name": "display_name", + }, dev) + + + @defer.inlineCallbacks + def test_multiple_devices(self): + now = 1470174257070 + + yield self.store.set_e2e_device_keys( + "user1", "device1", now, 'json11') + yield self.store.set_e2e_device_keys( + "user1", "device2", now, 'json12') + yield self.store.set_e2e_device_keys( + "user2", "device1", now, 'json21') + yield self.store.set_e2e_device_keys( + "user2", "device2", now, 'json22') + + res = yield self.store.get_e2e_device_keys((("user1", "device1"), + ("user2", "device2"))) + self.assertIn("user1", res) + self.assertIn("device1", res["user1"]) + self.assertNotIn("device2", res["user1"]) + self.assertIn("user2", res) + self.assertNotIn("device1", res["user2"]) + self.assertIn("device2", res["user2"]) -- cgit 1.4.1 From 98385888b891f544c68b749746604b593f98d729 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Aug 2016 14:57:46 +0100 Subject: PEP8 --- synapse/storage/end_to_end_keys.py | 20 ++++++++++---------- tests/storage/test_end_to_end_keys.py | 2 -- 2 files changed, 10 insertions(+), 12 deletions(-) (limited to 'tests') diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 5c8ed3e492..385d607056 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -63,16 +63,16 @@ class EndToEndKeyStore(SQLBaseStore): query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " - " d.display_name AS device_display_name, " - " k.key_json" - " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" - " WHERE %s" - ) % ( - " OR ".join("("+q+")" for q in query_clauses) - ) + "SELECT k.user_id, k.device_id, " + " d.display_name AS device_display_name, " + " k.key_json" + " FROM e2e_device_keys_json k" + " LEFT JOIN devices d ON d.user_id = k.user_id" + " AND d.device_id = k.device_id" + " WHERE %s" + ) % ( + " OR ".join("(" + q + ")" for q in query_clauses) + ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 0ebc6dafe8..453bc61438 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -15,7 +15,6 @@ from twisted.internet import defer -import synapse.api.errors import tests.unittest import tests.utils @@ -68,7 +67,6 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): "device_display_name": "display_name", }, dev) - @defer.inlineCallbacks def test_multiple_devices(self): now = 1470174257070 -- cgit 1.4.1 From e97648c4e219d56fa7b46cdd763a588e84a23f02 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Aug 2016 16:08:32 +0100 Subject: Test summarization --- synapse/rest/media/v1/preview_url_resource.py | 106 ++++++++++---------- tests/test_preview.py | 139 ++++++++++++++++++++++++++ 2 files changed, 193 insertions(+), 52 deletions(-) create mode 100644 tests/test_preview.py (limited to 'tests') diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index ebd07d696b..b2e2e7e624 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -347,58 +347,7 @@ class PreviewUrlResource(Resource): re.sub(r'\s+', '\n', el.text).strip() for el in cloned_tree.iter() if el.text ) - - # Try to get a summary of between 200 and 500 words, respecting - # first paragraph and then word boundaries. - # TODO: Respect sentences? - MIN_SIZE = 200 - MAX_SIZE = 500 - - description = '' - - # Keep adding paragraphs until we get to the MIN_SIZE. - for text_node in text_nodes: - if len(description) < MIN_SIZE: - description += text_node + '\n' - else: - break - - description = description.strip() - description = re.sub(r'[\t ]+', ' ', description) - description = re.sub(r'[\t \r\n]*[\r\n]+', '\n', description) - - # If the concatenation of paragraphs to get above MIN_SIZE - # took us over MAX_SIZE, then we need to truncate mid paragraph - if len(description) > MAX_SIZE: - new_desc = "" - - # This splits the paragraph into words, but keeping the - # (preceeding) whitespace intact so we can easily concat - # words back together. - for match in re.finditer("\s*\S+", description): - word = match.group() - - # Keep adding words while the total length is less than - # MAX_SIZE. - if len(word) + len(new_desc) < MAX_SIZE: - new_desc += word - else: - # At this point the next word *will* take us over - # MAX_SIZE, but we also want to ensure that its not - # a huge word. If it is add it anyway and we'll - # truncate later. - if len(new_desc) < MIN_SIZE: - new_desc += word - break - - # Double check that we're not over the limit - if len(new_desc) > MAX_SIZE: - new_desc = new_desc[:MAX_SIZE] - - # We always add an ellipsis because at the very least - # we chopped mid paragraph. - description = new_desc.strip() + "…" - og['og:description'] = description if description else None + og['og:description'] = _summarize_paragraphs(text_nodes) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG @@ -506,3 +455,56 @@ class PreviewUrlResource(Resource): content_type.startswith("application/xhtml") ): return True + + +def summarize_paragraphs(text_nodes, min_size=200, max_size=500): + # Try to get a summary of between 200 and 500 words, respecting + # first paragraph and then word boundaries. + # TODO: Respect sentences? + + description = '' + + # Keep adding paragraphs until we get to the MIN_SIZE. + for text_node in text_nodes: + if len(description) < min_size: + text_node = re.sub(r'[\t \r\n]+', ' ', text_node) + description += text_node + '\n\n' + else: + break + + description = description.strip() + description = re.sub(r'[\t ]+', ' ', description) + description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) + + # If the concatenation of paragraphs to get above MIN_SIZE + # took us over MAX_SIZE, then we need to truncate mid paragraph + if len(description) > max_size: + new_desc = "" + + # This splits the paragraph into words, but keeping the + # (preceeding) whitespace intact so we can easily concat + # words back together. + for match in re.finditer("\s*\S+", description): + word = match.group() + + # Keep adding words while the total length is less than + # MAX_SIZE. + if len(word) + len(new_desc) < max_size: + new_desc += word + else: + # At this point the next word *will* take us over + # MAX_SIZE, but we also want to ensure that its not + # a huge word. If it is add it anyway and we'll + # truncate later. + if len(new_desc) < min_size: + new_desc += word + break + + # Double check that we're not over the limit + if len(new_desc) > max_size: + new_desc = new_desc[:max_size] + + # We always add an ellipsis because at the very least + # we chopped mid paragraph. + description = new_desc.strip() + "…" + return description if description else None diff --git a/tests/test_preview.py b/tests/test_preview.py new file mode 100644 index 0000000000..2a801173a0 --- /dev/null +++ b/tests/test_preview.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import unittest + +from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs + + +class PreviewTestCase(unittest.TestCase): + + def test_long_summarize(self): + example_paras = [ + """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: + Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in + Troms county, Norway. The administrative centre of the municipality is + the city of Tromsø. Outside of Norway, Tromso and Tromsö are + alternative spellings of the city.Tromsø is considered the northernmost + city in the world with a population above 50,000. The most populous town + north of it is Alta, Norway, with a population of 14,272 (2013).""", + + """Tromsø lies in Northern Norway. The municipality has a population of + (2015) 72,066, but with an annual influx of students it has over 75,000 + most of the year. It is the largest urban area in Northern Norway and the + third largest north of the Arctic Circle (following Murmansk and Norilsk). + Most of Tromsø, including the city centre, is located on the island of + Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, + Tromsøya had a population of 36,088. Substantial parts of the urban area + are also situated on the mainland to the east, and on parts of Kvaløya—a + large island to the west. Tromsøya is connected to the mainland by the Tromsø + Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the + Sandnessund Bridge. Tromsø Airport connects the city to many destinations + in Europe. The city is warmer than most other places located on the same + latitude, due to the warming effect of the Gulf Stream.""", + + """The city centre of Tromsø contains the highest number of old wooden + houses in Northern Norway, the oldest house dating from 1789. The Arctic + Cathedral, a modern church from 1965, is probably the most famous landmark + in Tromsø. The city is a cultural centre for its region, with several + festivals taking place in the summer. Some of Norway's best-known + musicians, Torbjørn Brundtland and Svein Berge of the electronica duo + Röyksopp and Lene Marlin grew up and started their careers in Tromsø. + Noted electronic musician Geir Jenssen also hails from Tromsø.""", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEquals( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway. The administrative centre of the municipality is" + " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" + " alternative spellings of the city.Tromsø is considered the northernmost" + " city in the world with a population above 50,000. The most populous town" + " north of it is Alta, Norway, with a population of 14,272 (2013)." + ) + + desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) + + self.assertEquals( + desc, + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. It is the largest urban area in Northern Norway and the" + " third largest north of the Arctic Circle (following Murmansk and Norilsk)." + " Most of Tromsø, including the city centre, is located on the island of" + " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," + " Tromsøya had a population of 36,088. Substantial parts of the…" + ) + + def test_short_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + + "The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + + self.assertEquals( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + ) + + def test_small_then_large_summarize(self): + example_paras = [ + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + " The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", + ] + + desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) + self.assertEquals( + desc, + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. The city centre of Tromsø contains the highest number" + " of old wooden houses in Northern Norway, the oldest house dating from" + " 1789. The Arctic Cathedral, a modern church…" + ) -- cgit 1.4.1