From 887c6e6f052e1dc5e61a0b4bade8e7bd3a63e275 Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 31 May 2016 11:05:16 +0100 Subject: Split out the room list handler So I can use it from federation bits without pulling in all the handlers. --- synapse/rest/client/v1/room.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 644aa4e513..2d22bbdaa3 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -279,7 +279,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - handler = self.handlers.room_list_handler + handler = self.hs.get_room_list_handler() data = yield handler.get_public_room_list() defer.returnValue((200, data)) -- cgit 1.5.1 From d240796dedcfae1f6929c1501e7e335df417cfaf Mon Sep 17 00:00:00 2001 From: David Baker Date: Tue, 31 May 2016 17:20:07 +0100 Subject: Basic, un-cached support for secondary_directory_servers --- synapse/federation/federation_client.py | 21 +++++++++++++++++++++ synapse/federation/transport/client.py | 12 ++++++++++++ synapse/federation/transport/server.py | 2 +- synapse/handlers/room.py | 33 ++++++++++++++++++++++++++++++++- synapse/rest/client/v1/room.py | 3 ++- 5 files changed, 68 insertions(+), 3 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 37ee469fa2..ba8d71c050 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -24,6 +24,7 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent @@ -550,6 +551,26 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") + @defer.inlineCallbacks + def get_public_rooms(self, destinations): + results_by_server = {} + + @defer.inlineCallbacks + def _get_result(s): + if s == self.server_name: + defer.returnValue() + + try: + result = yield self.transport_layer.get_public_rooms(s) + results_by_server[s] = result + except: + logger.exception("Error getting room list from server %r", s) + + + yield concurrently_execute(_get_result, destinations, 3) + + defer.returnValue(results_by_server) + @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): """ diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cd2841c4db..ebb698e278 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -224,6 +224,18 @@ class TransportLayerClient(object): defer.returnValue(response) + @defer.inlineCallbacks + @log_function + def get_public_rooms(self, remote_server): + path = PREFIX + "/publicRooms" + + response = yield self.client.get_json( + destination=remote_server, + path=path, + ) + + defer.returnValue(response) + @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index f23c02efde..da9e7a326d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -527,7 +527,7 @@ class PublicRoomList(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, request): - data = yield self.room_list_handler.get_public_room_list() + data = yield self.room_list_handler.get_local_public_room_list() defer.returnValue((200, data)) # Avoid doing remote HS authorization checks which are done by default by diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d63b3c513..b0aa9fb511 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -345,7 +345,7 @@ class RoomListHandler(BaseHandler): super(RoomListHandler, self).__init__(hs) self.response_cache = ResponseCache() - def get_public_room_list(self): + def get_local_public_room_list(self): result = self.response_cache.get(()) if not result: result = self.response_cache.set((), self._get_public_room_list()) @@ -427,6 +427,37 @@ class RoomListHandler(BaseHandler): # FIXME (erikj): START is no longer a valid value defer.returnValue({"start": "START", "end": "END", "chunk": results}) + @defer.inlineCallbacks + def get_aggregated_public_room_list(self): + """ + Get the public room list from this server and the servers + specified in the secondary_directory_servers config option. + XXX: Pagination... + """ + federated_by_server = yield self.hs.get_replication_layer().get_public_rooms( + self.hs.config.secondary_directory_servers + ) + public_rooms = yield self.get_local_public_room_list() + + # keep track of which room IDs we've seen so we can de-dup + room_ids = set() + + # tag all the ones in our list with our server name. + # Also add the them to the de-deping set + for room in public_rooms['chunk']: + room["server_name"] = self.hs.hostname + room_ids.add(room["room_id"]) + + # Now add the results from federation + for server_name, server_result in federated_by_server.items(): + for room in server_result["chunk"]: + if room["room_id"] not in room_ids: + room["server_name"] = server_name + public_rooms["chunk"].append(room) + room_ids.add(room["room_id"]) + + defer.returnValue(public_rooms) + class RoomContextHandler(BaseHandler): @defer.inlineCallbacks diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2d22bbdaa3..db52a1fc39 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -280,7 +280,8 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): handler = self.hs.get_room_list_handler() - data = yield handler.get_public_room_list() + data = yield handler.get_aggregated_public_room_list() + defer.returnValue((200, data)) -- cgit 1.5.1 From 991af8b0d6406b633386384d823e5c3a9c2ceb8b Mon Sep 17 00:00:00 2001 From: David Baker Date: Wed, 1 Jun 2016 17:40:52 +0100 Subject: WIP on unsubscribing email notifs without logging in --- synapse/api/auth.py | 25 +++++++++++------- synapse/rest/client/v1/pusher.py | 55 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 10 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 2474a1453b..2ece59bb19 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module contains classes for authenticating the user.""" from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json, SignatureVerifyException @@ -42,13 +41,20 @@ AuthEventTypes = ( class Auth(object): - + """ + FIXME: This class contains a mix of functions for authenticating users + of our client-server API and authenticating events added to room graphs. + """ def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 + # Docs for these currently lives at + # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst + # In addition, we have type == delete_pusher which grants access only to + # delete pushers. self._KNOWN_CAVEAT_PREFIXES = set([ "gen = ", "guest = ", @@ -507,7 +513,7 @@ class Auth(object): return default @defer.inlineCallbacks - def get_user_by_req(self, request, allow_guest=False): + def get_user_by_req(self, request, allow_guest=False, rights="access"): """ Get a registered user's ID. Args: @@ -529,7 +535,7 @@ class Auth(object): ) access_token = request.args["access_token"][0] - user_info = yield self.get_user_by_access_token(access_token) + user_info = yield self.get_user_by_access_token(access_token, rights) user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] @@ -590,7 +596,7 @@ class Auth(object): defer.returnValue(user_id) @defer.inlineCallbacks - def get_user_by_access_token(self, token): + def get_user_by_access_token(self, token, rights="access"): """ Get a registered user's ID. Args: @@ -601,7 +607,7 @@ class Auth(object): AuthError if no user by that token exists or the token is invalid. """ try: - ret = yield self.get_user_from_macaroon(token) + ret = yield self.get_user_from_macaroon(token, rights) except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. @@ -609,11 +615,11 @@ class Auth(object): defer.returnValue(ret) @defer.inlineCallbacks - def get_user_from_macaroon(self, macaroon_str): + def get_user_from_macaroon(self, macaroon_str, rights="access"): try: macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) - self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token) + self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token) user_prefix = "user_id = " user = None @@ -667,7 +673,8 @@ class Auth(object): Args: macaroon(pymacaroons.Macaroon): The macaroon to validate - type_string(str): The kind of token this is (e.g. "access", "refresh") + type_string(str): The kind of token required (e.g. "access", "refresh", + "delete_pusher") verify_expiry(bool): Whether to verify whether the macaroon has expired. This should really always be True, but no clients currently implement token refresh, so we can't enforce expiry yet. diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index ab928a16da..fa7a0992dd 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,7 +17,11 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import ( + parse_json_object_from_request, parse_string, RestServlet +) +from synapse.http.server import finish_request +from synapse.api.errors import StoreError from .base import ClientV1RestServlet, client_path_patterns @@ -136,6 +140,55 @@ class PushersSetRestServlet(ClientV1RestServlet): return 200, {} +class PushersRemoveRestServlet(RestServlet): + """ + To allow pusher to be delete by clicking a link (ie. GET request) + """ + PATTERNS = client_path_patterns("/pushers/remove$") + SUCCESS_HTML = "You have been unsubscribed" + + def __init__(self, hs): + super(RestServlet, self).__init__() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, "delete_pusher") + user = requester.user + + app_id = parse_string(request, "app_id", required=True) + pushkey = parse_string(request, "pushkey", required=True) + + pusher_pool = self.hs.get_pusherpool() + + try: + yield pusher_pool.remove_pusher( + app_id=app_id, + pushkey=pushkey, + user_id=user.to_string(), + ) + except StoreError as se: + if se.code != 404: + # This is fine: they're already unsubscribed + raise + + self.notifier.on_new_replication_data() + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % ( + len(PushersRemoveRestServlet.SUCCESS_HTML), + )) + request.write(PushersRemoveRestServlet.SUCCESS_HTML) + finish_request(request) + defer.returnValue(None) + + def on_OPTIONS(self, _): + return 200, {} + + def register_servlets(hs, http_server): PushersRestServlet(hs).register(http_server) PushersSetRestServlet(hs).register(http_server) + PushersRemoveRestServlet(hs).register(http_server) -- cgit 1.5.1 From 4a10510cd5aff790127a185ecefc83b881a717cc Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 2 Jun 2016 13:31:45 +0100 Subject: Split out the auth handler --- synapse/handlers/__init__.py | 2 -- synapse/handlers/register.py | 2 +- synapse/rest/client/v1/login.py | 11 ++++++----- synapse/rest/client/v2_alpha/account.py | 4 ++-- synapse/rest/client/v2_alpha/auth.py | 2 +- synapse/rest/client/v2_alpha/register.py | 2 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 2 +- synapse/server.py | 5 +++++ tests/rest/client/v2_alpha/test_register.py | 2 +- tests/utils.py | 15 +++++---------- 10 files changed, 23 insertions(+), 24 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index c0069e23d6..d28e07f0d9 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -24,7 +24,6 @@ from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler from .admin import AdminHandler -from .auth import AuthHandler from .identity import IdentityHandler from .receipts import ReceiptsHandler from .search import SearchHandler @@ -50,7 +49,6 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) - self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 16f33f8371..bbc07b045e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -413,7 +413,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue((user_id, token)) def auth_handler(self): - return self.hs.get_handlers().auth_handler + return self.hs.get_auth_handler() @defer.inlineCallbacks def guest_access_token_for(self, medium, address, inviter_user_id): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3b5544851b..8df9d10efa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -58,6 +58,7 @@ class LoginRestServlet(ClientV1RestServlet): self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() + self.auth_handler = self.hs.get_auth_handler() def on_GET(self, request): flows = [] @@ -143,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id, self.hs.hostname ).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id=user_id, password=login_submission["password"]) @@ -160,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @@ -194,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( @@ -243,7 +244,7 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( @@ -412,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if not user_exists: user_id, _ = ( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c88c270537..9a84873a5f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet): super(PasswordRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -97,7 +97,7 @@ class ThreepidRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 78181b7b18..58d3cad6a1 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet): super(AuthRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler @defer.inlineCallbacks diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1ecc02d94d..2088c316d1 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -49,7 +49,7 @@ class RegisterRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - self.auth_handler = hs.get_handlers().auth_handler + self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index a158c2209a..8270e8787f 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet): body = parse_json_object_from_request(request) try: old_refresh_token = body["refresh_token"] - auth_handler = self.hs.get_handlers().auth_handler + auth_handler = self.hs.get_auth_handler() (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( old_refresh_token, auth_handler.generate_refresh_token) new_access_token = yield auth_handler.issue_access_token(user_id) diff --git a/synapse/server.py b/synapse/server.py index 7cf22b1eea..dd4b81c658 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -33,6 +33,7 @@ from synapse.handlers.presence import PresenceHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.room import RoomListHandler +from synapse.handlers.auth import AuthHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.state import StateHandler from synapse.storage import DataStore @@ -89,6 +90,7 @@ class HomeServer(object): 'sync_handler', 'typing_handler', 'room_list_handler', + 'auth_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -190,6 +192,9 @@ class HomeServer(object): def build_room_list_handler(self): return RoomListHandler(self) + def build_auth_handler(self): + return AuthHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index affd42c015..cda0a2b27c 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -33,7 +33,6 @@ class RegisterRestServletTestCase(unittest.TestCase): # do the dance to hook it up to the hs global self.handlers = Mock( - auth_handler=self.auth_handler, registration_handler=self.registration_handler, identity_handler=self.identity_handler, login_handler=self.login_handler @@ -42,6 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) + self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.config.enable_registration = True # init the thing we're testing diff --git a/tests/utils.py b/tests/utils.py index 006abedbc1..e19ae581e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -81,16 +81,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): ) # bcrypt is far too slow to be doing in unit tests - def swap_out_hash_for_testing(old_build_handlers): - def build_handlers(): - handlers = old_build_handlers() - auth_handler = handlers.auth_handler - auth_handler.hash = lambda p: hashlib.md5(p).hexdigest() - auth_handler.validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h - return handlers - return build_handlers - - hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest() + hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h fed = kargs.get("resource_for_federation", None) if fed: -- cgit 1.5.1 From 1f31cc37f8611f9ae5612ef5be82e63735fbdf34 Mon Sep 17 00:00:00 2001 From: David Baker Date: Thu, 2 Jun 2016 17:21:31 +0100 Subject: Working unsubscribe links going straight to the HS and authed by macaroons that let you delete pushers and nothing else --- synapse/api/auth.py | 7 +++++++ synapse/app/pusher.py | 23 ++++++++++++++++++++++- synapse/push/mailer.py | 8 ++++---- synapse/rest/client/v1/pusher.py | 4 +++- 4 files changed, 36 insertions(+), 6 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 463bd8b692..31e1abb964 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -660,6 +660,13 @@ class Auth(object): "is_guest": True, "token_id": None, } + elif rights == "delete_pusher": + # We don't store these tokens in the database + ret = { + "user": user, + "is_guest": False, + "token_id": None, + } else: # This codepath exists so that we can actually return a # token ID, because we use token IDs in place of device diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 135dd58c15..f1de1e7ce9 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -21,6 +21,7 @@ from synapse.config._base import ConfigError from synapse.config.database import DatabaseConfig from synapse.config.logger import LoggingConfig from synapse.config.emailconfig import EmailConfig +from synapse.config.key import KeyConfig from synapse.http.site import SynapseSite from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.storage.roommember import RoomMemberStore @@ -63,6 +64,26 @@ class SlaveConfig(DatabaseConfig): self.pid_file = self.abspath(config.get("pid_file")) self.public_baseurl = config["public_baseurl"] + # some things used by the auth handler but not actually used in the + # pusher codebase + self.bcrypt_rounds = None + self.ldap_enabled = None + self.ldap_server = None + self.ldap_port = None + self.ldap_tls = None + self.ldap_search_base = None + self.ldap_search_property = None + self.ldap_email_property = None + self.ldap_full_name_property = None + + # We would otherwise try to use the registration shared secret as the + # macaroon shared secret if there was no macaroon_shared_secret, but + # that means pulling in RegistrationConfig too. We don't need to be + # backwards compaitible in the pusher codebase so just make people set + # macaroon_shared_secret. We set this to None to prevent it referencing + # an undefined key. + self.registration_shared_secret = None + def default_config(self, server_name, **kwargs): pid_file = self.abspath("pusher.pid") return """\ @@ -95,7 +116,7 @@ class SlaveConfig(DatabaseConfig): """ % locals() -class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig): +class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig): pass diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index e877d8fdad..60d3700afa 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -81,7 +81,7 @@ class Mailer(object): def __init__(self, hs, app_name): self.hs = hs self.store = self.hs.get_datastore() - self.handlers = self.hs.get_handlers() + self.auth_handler = self.hs.get_auth_handler() self.state_handler = self.hs.get_state_handler() loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) self.app_name = app_name @@ -161,7 +161,7 @@ class Mailer(object): template_vars = { "user_display_name": user_display_name, - "unsubscribe_link": self.make_unsubscribe_link(app_id, email_address), + "unsubscribe_link": self.make_unsubscribe_link(user_id, app_id, email_address), "summary_text": summary_text, "app_name": self.app_name, "rooms": rooms, @@ -427,9 +427,9 @@ class Mailer(object): notif['room_id'], notif['event_id'] ) - def make_unsubscribe_link(self, app_id, email_address): + def make_unsubscribe_link(self, user_id, app_id, email_address): params = { - "access_token": self.handlers.auth.generate_delete_pusher_token(), + "access_token": self.auth_handler.generate_delete_pusher_token(user_id), "app_id": app_id, "pushkey": email_address, } diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index fa7a0992dd..9a2ed6ed88 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -149,11 +149,13 @@ class PushersRemoveRestServlet(RestServlet): def __init__(self, hs): super(RestServlet, self).__init__() + self.hs = hs self.notifier = hs.get_notifier() + self.auth = hs.get_v1auth() @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, "delete_pusher") + requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user app_id = parse_string(request, "app_id", required=True) -- cgit 1.5.1 From 6a0afa582aa5bf816e082af31ac44e2a8fee28c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jun 2016 14:27:07 +0100 Subject: Load push rules in storage layer, so that they get cached --- synapse/handlers/sync.py | 5 ++--- synapse/push/bulk_push_rule_evaluator.py | 28 ----------------------- synapse/push/clientformat.py | 30 ++++++++++++++++++------- synapse/rest/client/v1/push_rule.py | 6 ++--- synapse/storage/push_rule.py | 38 +++++++++++++++++++++++++++++++- 5 files changed, 63 insertions(+), 44 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5307b62b85..be26a491ff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -198,9 +198,8 @@ class SyncHandler(object): @defer.inlineCallbacks def push_rules_for_user(self, user): user_id = user.to_string() - rawrules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - rules = format_push_rules_for_user(user, rawrules, enabled_map) + rules = yield self.store.get_push_rules_for_user(user_id) + rules = format_push_rules_for_user(user, rules) defer.returnValue(rules) @defer.inlineCallbacks diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index af5212a5d1..6e42121b1d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -18,7 +18,6 @@ import ujson as json from twisted.internet import defer -from .baserules import list_with_base_rules from .push_rule_evaluator import PushRuleEvaluatorForEvent from synapse.api.constants import EventTypes, Membership @@ -38,36 +37,9 @@ def decode_rule_json(rule): @defer.inlineCallbacks def _get_rules(room_id, user_ids, store): rules_by_user = yield store.bulk_get_push_rules(user_ids) - rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - rules_by_user = { - uid: list_with_base_rules([ - decode_rule_json(rule_list) - for rule_list in rules_by_user.get(uid, []) - ]) - for uid in user_ids - } - - # We apply the rules-enabled map here: bulk_get_push_rules doesn't - # fetch disabled rules, but this won't account for any server default - # rules the user has disabled, so we need to do this too. - for uid in user_ids: - user_enabled_map = rules_enabled_by_user.get(uid) - if not user_enabled_map: - continue - - for i, rule in enumerate(rules_by_user[uid]): - rule_id = rule['rule_id'] - - if rule_id in user_enabled_map: - if rule.get('enabled', True) != bool(user_enabled_map[rule_id]): - # Rules are cached across users. - rule = dict(rule) - rule['enabled'] = bool(user_enabled_map[rule_id]) - rules_by_user[uid][i] = rule - defer.returnValue(rules_by_user) diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index ae9db9ec2f..b3983f7940 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -23,10 +23,7 @@ import copy import simplejson as json -def format_push_rules_for_user(user, rawrules, enabled_map): - """Converts a list of rawrules and a enabled map into nested dictionaries - to match the Matrix client-server format for push rules""" - +def load_rules_for_user(user, rawrules, enabled_map): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -35,7 +32,26 @@ def format_push_rules_for_user(user, rawrules, enabled_map): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule['rule_id'] + if rule_id in enabled_map: + if rule.get('enabled', True) != bool(enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule['enabled'] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + +def format_push_rules_for_user(user, ruleslist): + """Converts a list of rawrules and a enabled map into nested dictionaries + to match the Matrix client-server format for push rules""" + + # We're going to be mutating this a lot, so do a deep copy + ruleslist = copy.deepcopy(ruleslist) rules = {'global': {}, 'device': {}} @@ -60,9 +76,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map): template_rule = _rule_to_template(r) if template_rule: - if r['rule_id'] in enabled_map: - template_rule['enabled'] = enabled_map[r['rule_id']] - elif 'enabled' in r: + if 'enabled' in r: template_rule['enabled'] = r['enabled'] else: template_rule['enabled'] = True diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 02d837ee6a..6bb4821ec6 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.store.get_push_rules_for_user(user_id) + rules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - - rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) + rules = format_push_rules_for_user(requester.user, rules) path = request.postpath[1:] diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index ebb97c8474..786d6f6d67 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -15,6 +15,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.push.baserules import list_with_base_rules from twisted.internet import defer import logging @@ -23,6 +24,29 @@ import simplejson as json logger = logging.getLogger(__name__) +def _load_rules(rawrules, enabled_map): + ruleslist = [] + for rawrule in rawrules: + rule = dict(rawrule) + rule["conditions"] = json.loads(rawrule["conditions"]) + rule["actions"] = json.loads(rawrule["actions"]) + ruleslist.append(rule) + + # We're going to be mutating this a lot, so do a deep copy + rules = list(list_with_base_rules(ruleslist)) + + for i, rule in enumerate(rules): + rule_id = rule['rule_id'] + if rule_id in enabled_map: + if rule.get('enabled', True) != bool(enabled_map[rule_id]): + # Rules are cached across users. + rule = dict(rule) + rule['enabled'] = bool(enabled_map[rule_id]) + rules[i] = rule + + return rules + + class PushRuleStore(SQLBaseStore): @cachedInlineCallbacks(lru=True) def get_push_rules_for_user(self, user_id): @@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore): key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) ) - defer.returnValue(rows) + enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + + rules = _load_rules(rows, enabled_map) + + defer.returnValue(rules) @cachedInlineCallbacks(lru=True) def get_push_rules_enabled_for_user(self, user_id): @@ -85,6 +113,14 @@ class PushRuleStore(SQLBaseStore): for row in rows: results.setdefault(row['user_name'], []).append(row) + + enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + + for user_id, rules in results.items(): + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}) + ) + defer.returnValue(results) @cachedList(cached_method_name="get_push_rules_enabled_for_user", -- cgit 1.5.1 From efeabd31801224cbacd31b61ff0d869b70b1820d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Jun 2016 14:23:15 +0100 Subject: Log user that is making /publicRooms calls --- synapse/rest/client/v1/room.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index db52a1fc39..604c2a565e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -279,6 +279,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): + try: + yield self.auth.get_user_by_req(request) + except AuthError: + # This endpoint isn't authed, but its useful to know who's hitting + # it if they *do* supply an access token + pass + handler = self.hs.get_room_list_handler() data = yield handler.get_aggregated_public_room_list() -- cgit 1.5.1 From 690029d1a3ebd26f56656a723fefdeafd71310e4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 8 Jun 2016 14:47:42 +0100 Subject: Don't make rooms visibile by default --- synapse/rest/client/v1/room.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 604c2a565e..86fbe2747d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -72,8 +72,6 @@ class RoomCreateRestServlet(ClientV1RestServlet): def get_room_config(self, request): user_supplied_config = parse_json_object_from_request(request) - # default visibility - user_supplied_config.setdefault("visibility", "public") return user_supplied_config def on_OPTIONS(self, request): -- cgit 1.5.1 From 95f305c35a790e8f10fef7e16268dfaba6bc4c31 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 9 Jun 2016 10:57:11 +0100 Subject: Remove redundant exception log in /events --- synapse/rest/client/v1/events.py | 45 +++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 24 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index d1afa0f0d5..498bb9e18a 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -45,30 +45,27 @@ class EventStreamRestServlet(ClientV1RestServlet): raise SynapseError(400, "Guest users must specify room_id param") if "room_id" in request.args: room_id = request.args["room_id"][0] - try: - handler = self.handlers.event_stream_handler - pagin_config = PaginationConfig.from_request(request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if "timeout" in request.args: - try: - timeout = int(request.args["timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - as_client_event = "raw" not in request.args - - chunk = yield handler.get_stream( - requester.user.to_string(), - pagin_config, - timeout=timeout, - as_client_event=as_client_event, - affect_presence=(not is_guest), - room_id=room_id, - is_guest=is_guest, - ) - except: - logger.exception("Event stream failed") - raise + + handler = self.handlers.event_stream_handler + pagin_config = PaginationConfig.from_request(request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if "timeout" in request.args: + try: + timeout = int(request.args["timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + as_client_event = "raw" not in request.args + + chunk = yield handler.get_stream( + requester.user.to_string(), + pagin_config, + timeout=timeout, + as_client_event=as_client_event, + affect_presence=(not is_guest), + room_id=room_id, + is_guest=is_guest, + ) defer.returnValue((200, chunk)) -- cgit 1.5.1 From a70688445dd7a9fa41a55a642fb9a394f291ae45 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Jun 2016 14:57:59 +0100 Subject: Implement purge_media_cache admin API --- synapse/rest/client/v1/admin.py | 32 +++++++++++++ synapse/rest/media/v1/filepath.py | 6 +++ synapse/rest/media/v1/media_repository.py | 78 +++++++++++++++++++++++-------- synapse/server.py | 5 ++ synapse/storage/media_repository.py | 29 ++++++++++++ 5 files changed, 130 insertions(+), 20 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index aa05b3f023..8ec8569a49 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -46,5 +46,37 @@ class WhoisRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class PurgeMediaCacheRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/purge_media_cache") + + def __init__(self, hs): + self.media_repository = hs.get_media_repository() + super(PurgeMediaCacheRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + before_ts = request.args.get("before_ts", None) + if not before_ts: + raise SynapseError(400, "Missing 'before_ts' arg") + + logger.info("before_ts: %r", before_ts[0]) + + try: + before_ts = int(before_ts[0]) + except Exception: + raise SynapseError(400, "Invalid 'before_ts' arg") + + ret = yield self.media_repository.delete_old_remote_media(before_ts) + + defer.returnValue((200, ret)) + + def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) + PurgeMediaCacheRestServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 422ab86fb3..0137458f71 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -65,3 +65,9 @@ class MediaFilePaths(object): file_id[0:2], file_id[2:4], file_id[4:], file_name ) + + def remote_media_thumbnail_dir(self, server_name, file_id): + return os.path.join( + self.base_path, "remote_thumbnail", server_name, + file_id[0:2], file_id[2:4], file_id[4:], + ) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 1a287b6fec..844628c121 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -30,11 +30,13 @@ from synapse.api.errors import SynapseError from twisted.internet import defer, threads -from synapse.util.async import ObservableDeferred +from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii from synapse.util.logcontext import preserve_context_over_fn import os +import errno +import shutil import cgi import logging @@ -47,7 +49,7 @@ UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 class MediaRepository(object): - def __init__(self, hs, filepaths): + def __init__(self, hs): self.auth = hs.get_auth() self.client = MatrixFederationHttpClient(hs) self.clock = hs.get_clock() @@ -55,11 +57,12 @@ class MediaRepository(object): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.filepaths = filepaths - self.downloads = {} + self.filepaths = MediaFilePaths(hs.config.media_store_path) self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements + self.remote_media_linearizer = Linearizer() + self.recently_accessed_remotes = set() self.clock.looping_call( @@ -112,22 +115,12 @@ class MediaRepository(object): defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) + @defer.inlineCallbacks def get_remote_media(self, server_name, media_id): key = (server_name, media_id) - download = self.downloads.get(key) - if download is None: - download = self._get_remote_media_impl(server_name, media_id) - download = ObservableDeferred( - download, - consumeErrors=True - ) - self.downloads[key] = download - - @download.addBoth - def callback(media_info): - del self.downloads[key] - return media_info - return download.observe() + with (yield self.remote_media_linearizer.queue(key)): + media_info = yield self._get_remote_media_impl(server_name, media_id) + defer.returnValue(media_info) @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): @@ -440,6 +433,52 @@ class MediaRepository(object): "height": m_height, }) + @defer.inlineCallbacks + def delete_old_remote_media(self, before_ts): + old_media = yield self.store.get_remote_media_before(before_ts) + + deleted = 0 + + for media in old_media: + origin = media["media_origin"] + media_id = media["media_id"] + file_id = media["filesystem_id"] + key = (origin, media_id) + + logger.info("Deleting: %r", key) + + with (yield self.remote_media_linearizer.queue(key)): + full_path = self.filepaths.remote_media_filepath(origin, file_id) + full_dir = os.path.dirname(full_path) + try: + os.remove(full_path) + except OSError as e: + logger.warn("Failed to remove file: %r", full_path) + if e.errno == errno.ENOENT: + pass + else: + continue + + try: + os.removedirs(full_dir) + except OSError: + pass + + thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( + origin, file_id + ) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + yield self.store.delete_remote_media(origin, media_id) + try: + os.removedirs(thumbnail_dir) + except OSError: + pass + + deleted += 1 + + defer.returnValue({"deleted": deleted}) + class MediaRepositoryResource(Resource): """File uploading and downloading. @@ -488,9 +527,8 @@ class MediaRepositoryResource(Resource): def __init__(self, hs): Resource.__init__(self) - filepaths = MediaFilePaths(hs.config.media_store_path) - media_repo = MediaRepository(hs, filepaths) + media_repo = hs.get_media_repository() self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo)) diff --git a/synapse/server.py b/synapse/server.py index dd4b81c658..d49a1a8a96 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -45,6 +45,7 @@ from synapse.crypto.keyring import Keyring from synapse.push.pusherpool import PusherPool from synapse.events.builder import EventBuilderFactory from synapse.api.filtering import Filtering +from synapse.rest.media.v1.media_repository import MediaRepository from synapse.http.matrixfederationclient import MatrixFederationHttpClient @@ -113,6 +114,7 @@ class HomeServer(object): 'filtering', 'http_client_context_factory', 'simple_http_client', + 'media_repository', ] def __init__(self, hostname, **kwargs): @@ -233,6 +235,9 @@ class HomeServer(object): **self.db_config.get("args", {}) ) + def build_media_repository(self): + return MediaRepository(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 44e4d38307..4c0f82353d 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -205,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore): }, desc="store_remote_media_thumbnail", ) + + def get_remote_media_before(self, before_ts): + sql = ( + "SELECT media_origin, media_id, filesystem_id" + " FROM remote_media_cache" + " WHERE last_access_ts < ?" + ) + + return self._execute( + "get_remote_media_before", self.cursor_to_dict, sql, before_ts + ) + + def delete_remote_media(self, media_origin, media_id): + def delete_remote_media_txn(txn): + self._simple_delete_txn( + txn, + "remote_media_cache", + keyvalues={ + "media_origin": media_origin, "media_id": media_id + }, + ) + self._simple_delete_txn( + txn, + "remote_media_cache_thumbnails", + keyvalues={ + "media_origin": media_origin, "media_id": media_id + }, + ) + return self.runInteraction("delete_remote_media", delete_remote_media_txn) -- cgit 1.5.1 From f328d95cef99763d056171846253ed68cab58214 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 Jun 2016 15:40:58 +0100 Subject: Feature: Add deactivate account admin API Allows server admins to "deactivate" accounts, which: - Revokes all access tokens - Removes all threepids - Removes password The API is a POST to `/admin/deactivate/` --- synapse/rest/client/v1/admin.py | 26 ++++++++++++++++++++++++++ synapse/storage/_base.py | 5 +++++ synapse/storage/registration.py | 9 +++++++++ 3 files changed, 40 insertions(+) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 8ec8569a49..e54c472e08 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -77,6 +77,32 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class DeactivateAccountRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/deactivate/(?P[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(DeactivateAccountRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # FIXME: Theoretically there is a race here wherein user resets password + # using threepid. + yield self.store.user_delete_access_tokens(target_user_id) + yield self.store.user_delete_threepids(target_user_id) + yield self.store.user_set_password_hash(target_user_id, None) + + defer.returnValue((200, {})) + + def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 32c6677d47..d766a30299 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -807,6 +807,11 @@ class SQLBaseStore(object): if txn.rowcount > 1: raise StoreError(500, "more than one row matched") + def _simple_delete(self, table, keyvalues, desc): + return self.runInteraction( + desc, self._simple_delete_txn, table, keyvalues + ) + @staticmethod def _simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 3de9e0f709..5c75dbab51 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -384,6 +384,15 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(ret['user_id']) defer.returnValue(None) + def user_delete_threepids(self, user_id): + return self._simple_delete( + "user_threepids", + keyvalues={ + "user_id": user_id, + }, + desc="user_delete_threepids", + ) + @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" -- cgit 1.5.1 From fc8007dbec40212ae85285aea600111ce2d06912 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Sun, 3 Jul 2016 15:08:15 +0900 Subject: Optionally include password hash in createUser endpoint Signed-off-by: Kent Shikama --- synapse/handlers/register.py | 4 ++-- synapse/rest/client/v1/register.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0b7517221d..e255f2da81 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -358,7 +358,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds): + def get_or_create_user(self, localpart, displayname, duration_seconds, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -394,7 +394,7 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=None, + password_hash=password_hash, create_profile_with_localpart=user.localpart, ) else: diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index e3f4fbb0bb..ef56d1e90f 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -410,12 +410,14 @@ class CreateUserRestServlet(ClientV1RestServlet): raise SynapseError(400, "Failed to parse 'duration_seconds'") if duration_seconds > self.direct_user_creation_max_duration: duration_seconds = self.direct_user_creation_max_duration + password_hash = user_json["password_hash"].encode("utf-8") if user_json["password_hash"] else None handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds + duration_seconds=duration_seconds, + password_hash=password_hash ) defer.returnValue({ -- cgit 1.5.1 From 2e5a31f1973b49ec1a89cfc042e00b51ba7e70fc Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Mon, 4 Jul 2016 22:00:13 +0900 Subject: Use .get() instead of [] to access password_hash --- synapse/rest/client/v1/register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ef56d1e90f..a923d5a198 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -410,7 +410,7 @@ class CreateUserRestServlet(ClientV1RestServlet): raise SynapseError(400, "Failed to parse 'duration_seconds'") if duration_seconds > self.direct_user_creation_max_duration: duration_seconds = self.direct_user_creation_max_duration - password_hash = user_json["password_hash"].encode("utf-8") if user_json["password_hash"] else None + password_hash = user_json["password_hash"].encode("utf-8") if user_json.get("password_hash") else None handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( -- cgit 1.5.1 From bb069079bbd0ce761403416ed4f77051352ed347 Mon Sep 17 00:00:00 2001 From: Kent Shikama Date: Mon, 4 Jul 2016 22:07:11 +0900 Subject: Fix style violations Signed-off-by: Kent Shikama --- synapse/handlers/register.py | 3 ++- synapse/rest/client/v1/register.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index e255f2da81..88c82ba7d0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -358,7 +358,8 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds, password_hash=None): + def get_or_create_user(self, localpart, displayname, duration_seconds, + password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index a923d5a198..d791d5e07e 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -410,7 +410,8 @@ class CreateUserRestServlet(ClientV1RestServlet): raise SynapseError(400, "Failed to parse 'duration_seconds'") if duration_seconds > self.direct_user_creation_max_duration: duration_seconds = self.direct_user_creation_max_duration - password_hash = user_json["password_hash"].encode("utf-8") if user_json.get("password_hash") else None + password_hash = user_json["password_hash"].encode("utf-8") \ + if user_json.get("password_hash") else None handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( -- cgit 1.5.1 From 2d21d43c34751cffb5f324bd58ceff060f65f679 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 10:28:51 +0100 Subject: Add purge_history API --- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 13 +++++++++++++ synapse/rest/client/v1/admin.py | 18 ++++++++++++++++++ synapse/storage/events.py | 6 ++++++ 4 files changed, 38 insertions(+), 1 deletion(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6c0bc7eafa..351b218247 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1413,7 +1413,7 @@ class FederationHandler(BaseHandler): local_view = dict(auth_events) remote_view = dict(auth_events) remote_view.update({ - (d.type, d.state_key): d for d in different_events + (d.type, d.state_key): d for d in different_events if d }) new_state, prev_state = self.state_handler.resolve_events( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15caf1950a..878809d50d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,6 +50,19 @@ class MessageHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() + @defer.inlineCallbacks + def purge_history(self, room_id, event_id): + event = yield self.store.get_event(event_id) + + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") + + depth = event.depth + + # TODO: Lock. + + yield self.store.delete_old_state(room_id, depth) + @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, as_client_event=True): diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index e54c472e08..71537a7d0b 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -77,6 +77,24 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class PurgeHistoryRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history/(?P[^/]*)/(?P[^/]*)" + ) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + yield self.handlers.message_handler.purge_history(room_id, event_id) + + defer.returnValue((200, {})) + + class DeactivateAccountRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/admin/deactivate/(?P[^/]*)") diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 98c917ce15..c3b498bb3d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1281,6 +1281,12 @@ class EventsStore(SQLBaseStore): ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) + def delete_old_state(self, room_id, topological_ordering): + return self.runInteraction( + "delete_old_state", + self._delete_old_state_txn, room_id, topological_ordering + ) + def _delete_old_state_txn(self, txn, room_id, topological_ordering): """Deletes old room state """ -- cgit 1.5.1 From caf33b2d9be1b992098a00ee61cf4b4009ee3a09 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 17:18:19 +0100 Subject: Protect password when registering using shared secret --- scripts/register_new_matrix_user | 11 ++++++++--- synapse/rest/client/v1/register.py | 11 +++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 27a6250b14..6d055fd012 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -25,12 +25,17 @@ import urllib2 import yaml -def request_registration(user, password, server_location, shared_secret): +def request_registration(user, password, server_location, shared_secret, admin=False): mac = hmac.new( key=shared_secret, - msg=user, digestmod=hashlib.sha1, - ).hexdigest() + ) + + mac.update(user) + mac.update(password) + mac.update("admin" if admin else "notadmin") + + mac = mac.hexdigest() data = { "user": user, diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index d791d5e07e..0eb7490e5d 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -324,6 +324,8 @@ class RegisterRestServlet(ClientV1RestServlet): raise SynapseError(400, "Shared secret registration is not enabled") user = register_json["user"].encode("utf-8") + password = register_json["password"].encode("utf-8") + admin = register_json.get("admin", None) # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface @@ -331,11 +333,12 @@ class RegisterRestServlet(ClientV1RestServlet): want_mac = hmac.new( key=self.hs.config.registration_shared_secret, - msg=user, digestmod=sha1, - ).hexdigest() - - password = register_json["password"].encode("utf-8") + ) + want_mac.update(user) + want_mac.update(password) + want_mac.update("admin" if admin else "notadmin") + want_mac = want_mac.hexdigest() if compare_digest(want_mac, got_mac): handler = self.handlers.registration_handler -- cgit 1.5.1 From 651faee698d5ff4806d1e0e7f5cd4c438bf434f1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 Jul 2016 17:30:22 +0100 Subject: Add an admin option to shared secret registration --- scripts/register_new_matrix_user | 19 ++++++++++-- synapse/handlers/register.py | 4 ++- synapse/rest/client/v1/register.py | 1 + synapse/storage/registration.py | 61 ++++++++++++++++++++++++-------------- 4 files changed, 58 insertions(+), 27 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 6d055fd012..987bf32d1c 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -42,6 +42,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F "password": password, "mac": mac, "type": "org.matrix.login.shared_secret", + "admin": admin, } server_location = server_location.rstrip("/") @@ -73,7 +74,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F sys.exit(1) -def register_new_user(user, password, server_location, shared_secret): +def register_new_user(user, password, server_location, shared_secret, admin): if not user: try: default_user = getpass.getuser() @@ -104,7 +105,14 @@ def register_new_user(user, password, server_location, shared_secret): print "Passwords do not match" sys.exit(1) - request_registration(user, password, server_location, shared_secret) + if not admin: + admin = raw_input("Make admin [no]: ") + if admin in ("y", "yes", "true"): + admin = True + else: + admin = False + + request_registration(user, password, server_location, shared_secret, bool(admin)) if __name__ == "__main__": @@ -124,6 +132,11 @@ if __name__ == "__main__": default=None, help="New password for user. Will prompt if omitted.", ) + parser.add_argument( + "-a", "--admin", + action="store_true", + help="Register new user as an admin. Will prompt if omitted.", + ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( @@ -156,4 +169,4 @@ if __name__ == "__main__": else: secret = args.shared_secret - register_new_user(args.user, args.password, args.server_url, secret) + register_new_user(args.user, args.password, args.server_url, secret, args.admin) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 88c82ba7d0..8c3381df8a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -90,7 +90,8 @@ class RegistrationHandler(BaseHandler): password=None, generate_token=True, guest_access_token=None, - make_guest=False + make_guest=False, + admin=False, ): """Registers a new client on the server. @@ -141,6 +142,7 @@ class RegistrationHandler(BaseHandler): # If the user was a guest then they already have a profile None if was_guest else user.localpart ), + admin=admin, ) else: # autogen a sequential user ID diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 0eb7490e5d..25d63a0b0b 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -345,6 +345,7 @@ class RegisterRestServlet(ClientV1RestServlet): user_id, token = yield handler.register( localpart=user, password=password, + admin=bool(admin), ) self._remove_session(session) defer.returnValue({ diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5c75dbab51..4999175ddb 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -77,7 +77,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def register(self, user_id, token, password_hash, was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_localpart=None): + create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: @@ -104,6 +104,7 @@ class RegistrationStore(SQLBaseStore): make_guest, appservice_id, create_profile_with_localpart, + admin ) self.get_user_by_id.invalidate((user_id,)) self.is_guest.invalidate((user_id,)) @@ -118,6 +119,7 @@ class RegistrationStore(SQLBaseStore): make_guest, appservice_id, create_profile_with_localpart, + admin, ): now = int(self.clock.time()) @@ -125,29 +127,42 @@ class RegistrationStore(SQLBaseStore): try: if was_guest: - txn.execute("UPDATE users SET" - " password_hash = ?," - " upgrade_ts = ?," - " is_guest = ?" - " WHERE name = ?", - [password_hash, now, 1 if make_guest else 0, user_id]) + txn.execute( + "UPDATE users SET" + " password_hash = ?," + " upgrade_ts = ?," + " is_guest = ?," + " admin = ?" + " WHERE name = ?", + (password_hash, now, 1 if make_guest else 0, admin, user_id,) + ) + self._simple_update_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + }, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": admin, + } + ) else: - txn.execute("INSERT INTO users " - "(" - " name," - " password_hash," - " creation_ts," - " is_guest," - " appservice_id" - ") " - "VALUES (?,?,?,?,?)", - [ - user_id, - password_hash, - now, - 1 if make_guest else 0, - appservice_id, - ]) + self._simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": admin, + } + ) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE -- cgit 1.5.1 From 0da24cac8bde47961396f7da774d8dc8ed847107 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Jul 2016 11:04:44 +0100 Subject: Add null separator to hmac --- scripts/register_new_matrix_user | 2 ++ synapse/rest/client/v1/register.py | 2 ++ 2 files changed, 4 insertions(+) (limited to 'synapse/rest/client/v1') diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index 987bf32d1c..12ed20d623 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -32,7 +32,9 @@ def request_registration(user, password, server_location, shared_secret, admin=F ) mac.update(user) + mac.update("\x00") mac.update(password) + mac.update("\x00") mac.update("admin" if admin else "notadmin") mac = mac.hexdigest() diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 25d63a0b0b..83872f5f60 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -336,7 +336,9 @@ class RegisterRestServlet(ClientV1RestServlet): digestmod=sha1, ) want_mac.update(user) + want_mac.update("\x00") want_mac.update(password) + want_mac.update("\x00") want_mac.update("admin" if admin else "notadmin") want_mac = want_mac.hexdigest() -- cgit 1.5.1 From 76b18df3d95cd881017a9aa5c8473409928faecd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Jul 2016 11:16:10 +0100 Subject: Check that there are no null bytes in user and passsword --- synapse/rest/client/v1/register.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 83872f5f60..ce7099b18f 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -327,6 +327,12 @@ class RegisterRestServlet(ClientV1RestServlet): password = register_json["password"].encode("utf-8") admin = register_json.get("admin", None) + # Its important to check as we use null bytes as HMAC field separators + if "\x00" in user: + raise SynapseError(400, "Invalid user") + if "\x00" in password: + raise SynapseError(400, "Invalid password") + # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface got_mac = str(register_json["mac"]) -- cgit 1.5.1 From 67f2c901ea4196d869380c1c5cdd8569934857ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Jul 2016 15:56:59 +0100 Subject: Add rest servlet. Fix SQL. --- synapse/rest/client/v1/admin.py | 1 + synapse/storage/events.py | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 71537a7d0b..b0cb31a448 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -124,3 +124,4 @@ def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) + PurgeHistoryRestServlet(hs).register(http_server) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c3b498bb3d..23ebd5d4c5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1384,10 +1384,6 @@ class EventsStore(SQLBaseStore): (event_id,) for event_id, state_key in event_rows if state_key is None and not self.hs.is_mine_id(event_id) ] - to_not_delete = [ - (event_id,) for event_id, state_key in event_rows - if state_key is not None or self.hs.is_mine_id(event_id) - ] for table in ( "events", "event_json", @@ -1424,7 +1420,10 @@ class EventsStore(SQLBaseStore): txn.executemany( "UPDATE events SET outlier = ?" " WHERE event_id = ?", - to_not_delete + [ + (True, event_id,) for event_id, state_key in event_rows + if state_key is not None or self.hs.is_mine_id(event_id) + ] ) -- cgit 1.5.1 From 0136a522b18a734db69171d60566f501c0ced663 Mon Sep 17 00:00:00 2001 From: Negar Fazeli Date: Fri, 8 Jul 2016 16:53:18 +0200 Subject: Bug fix: expire invalid access tokens --- synapse/api/auth.py | 3 +++ synapse/handlers/auth.py | 5 +++-- synapse/handlers/register.py | 6 +++--- synapse/rest/client/v1/register.py | 2 +- tests/api/test_auth.py | 31 ++++++++++++++++++++++++++++++- tests/handlers/test_register.py | 4 ++-- 6 files changed, 42 insertions(+), 9 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a4d658a9d0..521a52e001 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -629,7 +629,10 @@ class Auth(object): except AuthError: # TODO(daniel): Remove this fallback when all existing access tokens # have been re-issued as macaroons. + if self.hs.config.expire_access_token: + raise ret = yield self._look_up_user_by_access_token(token) + defer.returnValue(ret) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e259213a36..5a0ed9d6b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -637,12 +637,13 @@ class AuthHandler(BaseHandler): yield self.store.add_refresh_token_to_user(user_id, refresh_token) defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token(self, user_id, extra_caveats=None, + duration_in_ms=(60 * 60 * 1000)): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8c3381df8a..6b33b27149 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -360,7 +360,7 @@ class RegistrationHandler(BaseHandler): defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, localpart, displayname, duration_seconds, + def get_or_create_user(self, localpart, displayname, duration_in_ms, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -390,8 +390,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.auth_handler().generate_short_term_login_token( - user_id, duration_seconds) + token = self.auth_handler().generate_access_token( + user_id, None, duration_in_ms) if need_register: yield self.store.register( diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index ce7099b18f..8e1f1b7845 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -429,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet): user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds, + duration_in_ms=(duration_seconds * 1000), password_hash=password_hash ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ad269af0ec..960c23d631 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -281,7 +281,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user,)) - macaroon.add_first_party_caveat("time < 1") # ms + macaroon.add_first_party_caveat("time < -2000") # ms self.hs.clock.now = 5000 # seconds self.hs.config.expire_access_token = True @@ -293,3 +293,32 @@ class AuthTestCase(unittest.TestCase): yield self.auth.get_user_from_macaroon(macaroon.serialize()) self.assertEqual(401, cm.exception.code) self.assertIn("Invalid macaroon", cm.exception.msg) + + @defer.inlineCallbacks + def test_get_user_from_macaroon_with_valid_duration(self): + # TODO(danielwh): Remove this mock when we remove the + # get_user_by_access_token fallback. + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + self.store.get_user_by_access_token = Mock( + return_value={"name": "@baldrick:matrix.org"} + ) + + user_id = "@baldrick:matrix.org" + macaroon = pymacaroons.Macaroon( + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = access") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("time < 900000000") # ms + + self.hs.clock.now = 5000 # seconds + self.hs.config.expire_access_token = True + + user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) + user = user_info["user"] + self.assertEqual(UserID.from_string(user_id), user) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 69a5e5b1d4..a7de3c7c17 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase): http_client=None, expire_access_token=True) self.auth_handler = Mock( - generate_short_term_login_token=Mock(return_value='secret')) + generate_access_token=Mock(return_value='secret')) self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.hs.get_handlers().profile_handler = Mock() self.mock_handler = Mock(spec=[ - "generate_short_term_login_token", + "generate_access_token", ]) self.hs.get_auth_handler = Mock(return_value=self.auth_handler) -- cgit 1.5.1 From a98d2152049b0a61426ed3d8b6ac872a9ca3f535 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 14 Jul 2016 15:59:25 +0100 Subject: Add filter param to /messages API --- synapse/handlers/message.py | 16 ++++++++++++---- synapse/rest/client/v1/room.py | 11 ++++++++++- tests/storage/event_injector.py | 1 + tests/storage/test_events.py | 12 ++++++------ 4 files changed, 29 insertions(+), 11 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad2753c1b5..dc76d34a52 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -66,7 +66,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, event_filter=None): """Get messages in a room. Args: @@ -75,11 +75,11 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None Returns: dict: Pagination API results """ user_id = requester.user.to_string() - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: room_token = pagin_config.from_token.room_key @@ -129,8 +129,13 @@ class MessageHandler(BaseHandler): room_id, max_topo ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, ) next_token = pagin_config.from_token.copy_and_replace( @@ -144,6 +149,9 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) + if event_filter: + events = event_filter.filter(events) + events = yield filter_events_for_client( self.store, user_id, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86fbe2747d..866a1e9120 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event from synapse.http.servlet import parse_json_object_from_request import logging import urllib +import ujson as json logger = logging.getLogger(__name__) @@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args + filter_bytes = request.args.get("filter", None) + if filter_bytes: + filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + event_filter=event_filter, ) defer.returnValue((200, msgs)) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py index f22ba8db89..38556da9a7 100644 --- a/tests/storage/event_injector.py +++ b/tests/storage/event_injector.py @@ -30,6 +30,7 @@ class EventInjector: def create_room(self, room): builder = self.event_builder_factory.new({ "type": EventTypes.Create, + "sender": "", "room_id": room.to_string(), "content": {}, }) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 18a6cff0c7..3762b38e37 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_count_daily_messages(self): - self.db_pool.runQuery("DELETE FROM stats_reporting") + yield self.db_pool.runQuery("DELETE FROM stats_reporting") self.hs.clock.now = 100 @@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase): # it isn't old enough. count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(1, self.hs.clock.now) + yield self._assert_stats_reporting(1, self.hs.clock.now) # Already reported yesterday, two new events from today. yield self.event_injector.inject_message(room, user, "Yeah they are!") @@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += 60 * 60 * 24 count = yield self.store.count_daily_messages() self.assertEqual(2, count) # 2 since yesterday - self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever + yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever # Last reported too recently. yield self.event_injector.inject_message(room, user, "Who could disagree?") self.hs.clock.now += 60 * 60 * 22 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(4, self.hs.clock.now) + yield self._assert_stats_reporting(4, self.hs.clock.now) # Last reported too long ago yield self.event_injector.inject_message(room, user, "No one.") self.hs.clock.now += 60 * 60 * 26 count = yield self.store.count_daily_messages() self.assertIsNone(count) - self._assert_stats_reporting(5, self.hs.clock.now) + yield self._assert_stats_reporting(5, self.hs.clock.now) # And now let's actually report something yield self.event_injector.inject_message(room, user, "Indeed.") @@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase): self.hs.clock.now += (60 * 60 * 24) + 50 count = yield self.store.count_daily_messages() self.assertEqual(3, count) - self._assert_stats_reporting(8, self.hs.clock.now) + yield self._assert_stats_reporting(8, self.hs.clock.now) @defer.inlineCallbacks def _get_last_stream_token(self): -- cgit 1.5.1 From dcfd71aa4c4a1d3d71356fd2f5d854fb1db8fafa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 12:34:23 +0100 Subject: Refactor login flow Make sure that we have the canonical user_id *before* calling get_login_tuple_for_user_id. Replace login_with_password with a method which just validates the password, and have the caller call get_login_tuple_for_user_id. This brings the password flow into line with the other flows, and will give us a place to register the device_id if necessary. --- synapse/handlers/auth.py | 106 ++++++++++++++++++++++------------------ synapse/rest/client/v1/login.py | 41 +++++++++------- 2 files changed, 82 insertions(+), 65 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5a0ed9d6b9..983994fa95 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -230,7 +230,6 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) @@ -240,11 +239,7 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue(user_id) + return self._check_password(user_id, password) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -348,67 +343,66 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def login_with_password(self, user_id, password): + def validate_password_login(self, user_id, password): """ Authenticates the user with their username and password. Used only by the v1 login API. Args: - user_id (str): User ID + user_id (str): complete @user:id password (str): Password Returns: - A tuple of: - The user's ID. - The access token for the user's session. - The refresh token for the user's session. + defer.Deferred: (str) canonical user id Raises: - StoreError if there was a problem storing the token. + StoreError if there was a problem accessing the database LoginError if there was an authentication problem. """ - - if not (yield self._check_password(user_id, password)): - logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + return self._check_password(user_id, password) @defer.inlineCallbacks def get_login_tuple_for_user_id(self, user_id): """ Gets login tuple for the user with the given user ID. + + Creates a new access/refresh token for the user. + The user is assumed to have been authenticated by some other - machanism (e.g. CAS) + machanism (e.g. CAS), and the user_id converted to the canonical case. Args: - user_id (str): User ID + user_id (str): canonical User ID Returns: A tuple of: - The user's ID. The access token for the user's session. The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) - logger.info("Logging in user %s", user_id) access_token = yield self.issue_access_token(user_id) refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks - def does_user_exist(self, user_id): + def check_user_exists(self, user_id): + """ + Checks to see if a user with the given id exists. Will check case + insensitively, but return None if there are multiple inexact matches. + + Args: + (str) user_id: complete @user:id + + Returns: + defer.Deferred: (str) canonical_user_id, or None if zero or + multiple matches + """ try: - yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(True) + res = yield self._find_user_id_and_pwd_hash(user_id) + defer.returnValue(res[0]) except LoginError: - defer.returnValue(False) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): @@ -438,27 +432,45 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def _check_password(self, user_id, password): - """ + """Authenticate a user against the LDAP and local databases. + + user_id is checked case insensitively against the local database, but + will throw if there are multiple inexact matches. + + Args: + user_id (str): complete @user:id Returns: - True if the user_id successfully authenticated + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect """ valid_ldap = yield self._check_ldap_password(user_id, password) if valid_ldap: - defer.returnValue(True) + defer.returnValue(user_id) - valid_local_password = yield self._check_local_password(user_id, password) - if valid_local_password: - defer.returnValue(True) - - defer.returnValue(False) + result = yield self._check_local_password(user_id, password) + defer.returnValue(result) @defer.inlineCallbacks def _check_local_password(self, user_id, password): - try: - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(self.validate_hash(password, password_hash)) - except LoginError: - defer.returnValue(False) + """Authenticate a user against the local password database. + + user_id is checked case insensitively, but will throw if there are + multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id + Raises: + LoginError if the password was incorrect + """ + user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) + result = self.validate_hash(password, password_hash) + if not result: + logger.warn("Failed password login for user %s", user_id) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(user_id) @defer.inlineCallbacks def _check_ldap_password(self, user_id, password): @@ -570,7 +582,7 @@ class AuthHandler(BaseHandler): ) # check for existing account, if none exists, create one - if not (yield self.does_user_exist(user_id)): + if not (yield self.check_user_exists(user_id)): # query user metadata for account creation query = "({prop}={value})".format( prop=self.ldap_attributes['uid'], diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8df9d10efa..a1f2ba8773 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -145,10 +145,13 @@ class LoginRestServlet(ClientV1RestServlet): ).to_string() auth_handler = self.auth_handler - user_id, access_token, refresh_token = yield auth_handler.login_with_password( + user_id = yield auth_handler.validate_password_login( user_id=user_id, - password=login_submission["password"]) - + password=login_submission["password"], + ) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -165,7 +168,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - user_id, access_token, refresh_token = ( + access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id(user_id) ) result = { @@ -196,13 +199,15 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id + ) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -245,13 +250,13 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(registered_user_id) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -414,13 +419,13 @@ class CasTicketServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if not user_exists: - user_id, _ = ( + registered_user_id = yield auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(user_id) + login_token = auth_handler.generate_short_term_login_token(registered_user_id) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) -- cgit 1.5.1 From f863a52ceacf69ab19b073383be80603a2f51c0a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 Jul 2016 13:19:07 +0100 Subject: Add device_id support to /login Add a 'devices' table to the storage, as well as a 'device_id' column to refresh_tokens. Allow the client to pass a device_id, and initial_device_display_name, to /login. If login is successful, then register the device in the devices table if it wasn't known already. If no device_id was supplied, make one up. Associate the device_id with the access token and refresh token, so that we can get at it again later. Ensure that the device_id is copied from the refresh token to the access_token when the token is refreshed. --- synapse/handlers/auth.py | 19 +++--- synapse/handlers/device.py | 71 ++++++++++++++++++++ synapse/rest/client/v1/login.py | 39 ++++++++++- synapse/rest/client/v2_alpha/tokenrefresh.py | 10 ++- synapse/server.py | 5 ++ synapse/storage/__init__.py | 3 + synapse/storage/devices.py | 77 ++++++++++++++++++++++ synapse/storage/registration.py | 28 +++++--- synapse/storage/schema/delta/33/devices.sql | 21 ++++++ .../schema/delta/33/refreshtoken_device.sql | 16 +++++ tests/handlers/test_device.py | 75 +++++++++++++++++++++ tests/storage/test_registration.py | 21 ++++-- 12 files changed, 354 insertions(+), 31 deletions(-) create mode 100644 synapse/handlers/device.py create mode 100644 synapse/storage/devices.py create mode 100644 synapse/storage/schema/delta/33/devices.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device.sql create mode 100644 tests/handlers/test_device.py (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 983994fa95..ce9bc18849 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -361,7 +361,7 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_login_tuple_for_user_id(self, user_id, device_id=None): """ Gets login tuple for the user with the given user ID. @@ -372,6 +372,7 @@ class AuthHandler(BaseHandler): Args: user_id (str): canonical User ID + device_id (str): the device ID to associate with the access token Returns: A tuple of: The access token for the user's session. @@ -380,9 +381,9 @@ class AuthHandler(BaseHandler): StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + refresh_token = yield self.issue_refresh_token(user_id, device_id) defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks @@ -638,15 +639,17 @@ class AuthHandler(BaseHandler): defer.returnValue(False) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) @defer.inlineCallbacks - def issue_refresh_token(self, user_id): + def issue_refresh_token(self, user_id, device_id=None): refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) + yield self.store.add_refresh_token_to_user(user_id, refresh_token, + device_id) defer.returnValue(refresh_token) def generate_access_token(self, user_id, extra_caveats=None, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..8d7d9874f8 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import StoreError +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string_with_symbols(16) + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except StoreError: + attempts += 1 + + raise StoreError(500, "Couldn't generate a device ID.") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a1f2ba8773..e8b791519c 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -149,14 +150,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id=user_id, password=login_submission["password"], ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -168,14 +171,16 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) + device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -252,8 +257,13 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(registered_user_id) + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id + ) ) result = { "user_id": registered_user_id, @@ -262,6 +272,9 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -300,6 +313,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 8270e8787f..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet): try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_auth_handler() - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/server.py b/synapse/server.py index d49a1a8a96..e8b166990d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,6 +25,7 @@ from twisted.enterprise import adbapi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.api import ApplicationServiceApi from synapse.federation import initialize_http_replication +from synapse.handlers.device import DeviceHandler from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth @@ -92,6 +93,7 @@ class HomeServer(object): 'typing_handler', 'room_list_handler', 'auth_handler', + 'device_handler', 'application_service_api', 'application_service_scheduler', 'application_service_handler', @@ -197,6 +199,9 @@ class HomeServer(object): def build_auth_handler(self): return AuthHandler(self) + def build_device_handler(self): + return DeviceHandler(self) + def build_application_service_api(self): return ApplicationServiceApi(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1c93e18f9d..73fb334dd6 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. from twisted.internet import defer + +from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) @@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore, OpenIdStore, ClientIpStore, + DeviceStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py new file mode 100644 index 0000000000..9065e96d28 --- /dev/null +++ b/synapse/storage/devices.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from ._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class DeviceStore(SQLBaseStore): + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name, + ignore_if_known=True): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device + ignore_if_known (bool): ignore integrity errors which mean the + device is already known + Returns: + defer.Deferred + Raises: + StoreError: if ignore_if_known is False and the device was already + known + """ + try: + yield self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": device_id, + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=ignore_if_known, + ) + except Exception as e: + logger.error("store_device with device_id=%s failed: %s", + device_id, e) + raise StoreError(500, "Problem storing device.") + + def get_device(self, user_id, device_id): + """Retrieve a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred for a namedtuple containing the device information + Raises: + StoreError: if the device is not found + """ + return self._simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d957a629dc..26ef1cfd8a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -31,12 +31,14 @@ class RegistrationStore(SQLBaseStore): self.clock = hs.get_clock() @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,18 +49,21 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): + def add_refresh_token_to_user(self, user_id, token, device_id=None): """Adds a refresh token for the given user. Args: user_id (str): The user ID. token (str): The new refresh token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -69,7 +74,8 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_refresh_token_to_user", ) @@ -291,18 +297,18 @@ class RegistrationStore(SQLBaseStore): ) def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. + """Exchange a refresh token for a new one. Doing so invalidates the old refresh token - refresh tokens are single use. Args: - token (str): The refresh token of a user. + refresh_token (str): The refresh token of a user. token_generator (fn: str -> str): Function which, when given a user ID, returns a unique refresh token for that user. This function must never return the same value twice. Returns: - tuple of (user_id, refresh_token) + tuple of (user_id, new_refresh_token, device_id) Raises: StoreError if no user was found with that refresh token. """ @@ -314,12 +320,13 @@ class RegistrationStore(SQLBaseStore): ) def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" txn.execute(sql, (old_token,)) rows = self.cursor_to_dict(txn) if not rows: raise StoreError(403, "Did not recognize refresh token") user_id = rows[0]["user_id"] + device_id = rows[0]["device_id"] # TODO(danielwh): Maybe perform a validation on the macaroon that # macaroon.user_id == user_id. @@ -328,7 +335,7 @@ class RegistrationStore(SQLBaseStore): sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" txn.execute(sql, (new_token, old_token,)) - return user_id, new_token + return user_id, new_token, device_id @defer.inlineCallbacks def is_server_admin(self, user): @@ -356,7 +363,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/schema/delta/33/devices.sql new file mode 100644 index 0000000000..eca7268d82 --- /dev/null +++ b/synapse/storage/schema/delta/33/devices.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE devices ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + display_name TEXT, + CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/33/refreshtoken_device.sql new file mode 100644 index 0000000000..b21da00dde --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE refresh_tokens ADD COLUMN device_id BIGINT; diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py new file mode 100644 index 0000000000..cc6512ccc7 --- /dev/null +++ b/tests/handlers/test_device.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.handlers.device import DeviceHandler +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DeviceHandlers(object): + def __init__(self, hs): + self.device_handler = DeviceHandler(hs) + + +class DeviceTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = handlers = DeviceHandlers(self.hs) + self.handler = handlers.device_handler + + @defer.inlineCallbacks + def test_device_is_created_if_doesnt_exist(self): + res = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_is_preserved_if_exists(self): + res1 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="display name" + ) + self.assertEqual(res1, "fco") + + res2 = yield self.handler.check_device_registered( + user_id="boris", + device_id="fco", + initial_device_display_name="new display name" + ) + self.assertEqual(res2, "fco") + + dev = yield self.handler.store.get_device("boris", "fco") + self.assertEqual(dev["display_name"], "display name") + + @defer.inlineCallbacks + def test_device_id_is_made_up_if_unspecified(self): + device_id = yield self.handler.check_device_registered( + user_id="theresa", + device_id=None, + initial_device_display_name="display" + ) + + dev = yield self.handler.store.get_device("theresa", device_id) + self.assertEqual(dev["display_name"], "display") diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b8384c98d8..b03ca303a2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "BcDeFgHiJkLmNoPqRsTuVwXyZa" ] self.pwhash = "{xx1}123456789" + self.device_id = "akgjhdjklgshg" @defer.inlineCallbacks def test_register(self): @@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) - yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { "name": self.user_id, + "device_id": self.device_id, }, result ) @@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_exchange_refresh_token_valid(self): uid = stringutils.random_string(32) + device_id = stringutils.random_string(16) generator = TokenGenerator() last_token = generator.generate(uid) self.db_pool.runQuery( - "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", - (uid, last_token,)) + "INSERT INTO refresh_tokens(user_id, token, device_id) " + "VALUES(?,?,?)", + (uid, last_token, device_id)) - (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( - last_token, generator.generate) + (found_user_id, refresh_token, device_id) = \ + yield self.store.exchange_refresh_token(last_token, + generator.generate) self.assertEqual(uid, found_user_id) rows = yield self.db_pool.runQuery( - "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) - self.assertEqual([(refresh_token,)], rows) + "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", + (uid, )) + self.assertEqual([(refresh_token, device_id)], rows) # We issued token 1, then exchanged it for token 2 expected_refresh_token = u"%s-%d" % (uid, 2,) self.assertEqual(expected_refresh_token, refresh_token) -- cgit 1.5.1 From 40cbffb2d2ca0166f1377ac4ec5988046ea4ca10 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 18:46:19 +0100 Subject: Further registration refactoring * `RegistrationHandler.appservice_register` no longer issues an access token: instead it is left for the caller to do it. (There are two of these, one in `synapse/rest/client/v1/register.py`, which now simply calls `AuthHandler.issue_access_token`, and the other in `synapse/rest/client/v2_alpha/register.py`, which is covered below). * In `synapse/rest/client/v2_alpha/register.py`, move the generation of access_tokens into `_create_registration_details`. This means that the normal flow no longer needs to call `AuthHandler.issue_access_token`; the shared-secret flow can tell `RegistrationHandler.register` not to generate a token; and the appservice flow continues to work despite the above change. --- synapse/handlers/register.py | 13 +++++--- synapse/rest/client/v1/register.py | 4 ++- synapse/rest/client/v2_alpha/register.py | 50 +++++++++++++++++++++-------- synapse/storage/registration.py | 6 ++-- tests/rest/client/v2_alpha/test_register.py | 6 +++- 5 files changed, 57 insertions(+), 22 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6b33b27149..94b19d0cb0 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -99,8 +99,13 @@ class RegistrationHandler(BaseHandler): localpart : The local part of the user ID to register. If None, one will be generated. password (str) : The password to assign to this user so they can - login again. This can be None which means they cannot login again - via a password (e.g. the user is an application service user). + login again. This can be None which means they cannot login again + via a password (e.g. the user is an application service user). + generate_token (bool): Whether a new access token should be + generated. Having this be True should be considered deprecated, + since it offers no means of associating a device_id with the + access_token. Instead you should call auth_handler.issue_access_token + after registration. Returns: A tuple of (user_id, access_token). Raises: @@ -196,15 +201,13 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, - token=token, password_hash="", appservice_id=service_id, create_profile_with_localpart=user.localpart, ) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8e1f1b7845..28b59952c3 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -60,6 +60,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth_handler = hs.get_auth_handler() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -299,9 +300,10 @@ class RegisterRestServlet(ClientV1RestServlet): user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler - (user_id, token) = yield handler.appservice_register( + user_id = yield handler.appservice_register( user_localpart, as_token ) + token = yield self.auth_handler.issue_access_token(user_id) self._remove_session(session) defer.returnValue({ "user_id": user_id, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5db953a1e3..04004cfbbd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -226,19 +226,17 @@ class RegisterRestServlet(RestServlet): add_email = True - access_token = yield self.auth_handler.issue_access_token( + result = yield self._create_registration_details( registered_user_id ) if add_email and result and LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] yield self._register_email_threepid( - registered_user_id, threepid, access_token, + registered_user_id, threepid, result["access_token"], params.get("bind_email") ) - result = yield self._create_registration_details(registered_user_id, - access_token) defer.returnValue((200, result)) def on_OPTIONS(self, _): @@ -246,10 +244,10 @@ class RegisterRestServlet(RestServlet): @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token): - (user_id, token) = yield self.registration_handler.appservice_register( + user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + defer.returnValue((yield self._create_registration_details(user_id))) @defer.inlineCallbacks def _do_shared_secret_registration(self, username, password, mac): @@ -273,10 +271,12 @@ class RegisterRestServlet(RestServlet): 403, "HMAC incorrect", ) - (user_id, token) = yield self.registration_handler.register( - localpart=username, password=password + (user_id, _) = yield self.registration_handler.register( + localpart=username, password=password, generate_token=False, ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + + result = yield self._create_registration_details(user_id) + defer.returnValue(result) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token, bind_email): @@ -349,11 +349,31 @@ class RegisterRestServlet(RestServlet): defer.returnValue() @defer.inlineCallbacks - def _create_registration_details(self, user_id, token): - refresh_token = yield self.auth_handler.issue_refresh_token(user_id) + def _create_registration_details(self, user_id): + """Complete registration of newly-registered user + + Issues access_token and refresh_token, and builds the success response + body. + + Args: + (str) user_id: full canonical @user:id + + + Returns: + defer.Deferred: (object) dictionary for response from /register + """ + + access_token = yield self.auth_handler.issue_access_token( + user_id + ) + + refresh_token = yield self.auth_handler.issue_refresh_token( + user_id + ) + defer.returnValue({ "user_id": user_id, - "access_token": token, + "access_token": access_token, "home_server": self.hs.hostname, "refresh_token": refresh_token, }) @@ -366,7 +386,11 @@ class RegisterRestServlet(RestServlet): generate_token=False, make_guest=True ) - access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + access_token = self.auth_handler.generate_access_token( + user_id, ["guest = true"] + ) + # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok + # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26ef1cfd8a..9a92b35361 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -81,14 +81,16 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks - def register(self, user_id, token, password_hash, + def register(self, user_id, token=None, password_hash=None, was_guest=False, make_guest=False, appservice_id=None, create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. + token (str): The desired access token to use for this user. If this + is not None, the given access token is associated with the user + id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 9a4215fef7..ccbb8776d3 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -61,8 +61,10 @@ class RegisterRestServletTestCase(unittest.TestCase): "id": "1234" } self.registration_handler.appservice_register = Mock( - return_value=(user_id, token) + return_value=user_id ) + self.auth_handler.issue_access_token = Mock(return_value=token) + (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { @@ -126,6 +128,8 @@ class RegisterRestServletTestCase(unittest.TestCase): } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) + self.auth_handler.issue_access_token.assert_called_once_with( + user_id) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.5.1 From 3413f1e284593aa63723cdcd52f443d63771ef62 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Jul 2016 10:21:42 +0100 Subject: Type annotations Add some type annotations to help PyCharm (in particular) to figure out the types of a bunch of things. --- synapse/handlers/_base.py | 4 ++++ synapse/handlers/auth.py | 4 ++++ synapse/rest/client/v1/base.py | 4 ++++ synapse/rest/client/v1/register.py | 4 ++++ synapse/rest/client/v2_alpha/register.py | 9 +++++++++ synapse/server.pyi | 21 +++++++++++++++++++++ 6 files changed, 46 insertions(+) create mode 100644 synapse/server.pyi (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d00685c389..6264aa0d9a 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -36,6 +36,10 @@ class BaseHandler(object): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ce9bc18849..8f83923ddb 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -45,6 +45,10 @@ class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(AuthHandler, self).__init__(hs) self.checkers = { LoginType.PASSWORD: self._check_password_auth, diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 1c020b7e2c..96b49b01f2 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.hs = hs self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 8e1f1b7845..efe796c65f 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__(hs) # sessions are stored as: # self.sessions = { diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5db953a1e3..2722a58e3e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -45,6 +45,10 @@ class RegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register/email/requestToken$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRequestTokenRestServlet, self).__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -77,7 +81,12 @@ class RegisterRestServlet(RestServlet): PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/server.pyi b/synapse/server.pyi new file mode 100644 index 0000000000..902f725c06 --- /dev/null +++ b/synapse/server.pyi @@ -0,0 +1,21 @@ +import synapse.handlers +import synapse.handlers.auth +import synapse.handlers.device +import synapse.storage +import synapse.state + +class HomeServer(object): + def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: + pass + + def get_datastore(self) -> synapse.storage.DataStore: + pass + + def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: + pass + + def get_handlers(self) -> synapse.handlers.Handlers: + pass + + def get_state_handler(self) -> synapse.state.StateHandler: + pass -- cgit 1.5.1 From 436bffd15fb8382a0d2dddd3c6f7a077ba751da2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 22 Jul 2016 14:52:53 +0100 Subject: Implement deleting devices --- synapse/handlers/auth.py | 22 ++++++++++++++++-- synapse/handlers/device.py | 27 +++++++++++++++++++++- synapse/rest/client/v1/login.py | 13 ++++++++--- synapse/rest/client/v2_alpha/devices.py | 14 +++++++++++ synapse/rest/client/v2_alpha/register.py | 10 ++++---- synapse/storage/devices.py | 15 ++++++++++++ synapse/storage/registration.py | 26 +++++++++++++++++---- .../schema/delta/33/access_tokens_device_index.sql | 17 ++++++++++++++ .../schema/delta/33/refreshtoken_device_index.sql | 17 ++++++++++++++ tests/handlers/test_device.py | 22 ++++++++++++++++-- tests/rest/client/v2_alpha/test_register.py | 14 +++++++---- 11 files changed, 176 insertions(+), 21 deletions(-) create mode 100644 synapse/storage/schema/delta/33/access_tokens_device_index.sql create mode 100644 synapse/storage/schema/delta/33/refreshtoken_device_index.sql (limited to 'synapse/rest/client/v1') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d5d2072436..2e138f328f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -77,6 +77,7 @@ class AuthHandler(BaseHandler): self.ldap_bind_password = hs.config.ldap_bind_password self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -374,7 +375,8 @@ class AuthHandler(BaseHandler): return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id, device_id=None): + def get_login_tuple_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ Gets login tuple for the user with the given user ID. @@ -383,9 +385,15 @@ class AuthHandler(BaseHandler): The user is assumed to have been authenticated by some other machanism (e.g. CAS), and the user_id converted to the canonical case. + The device will be recorded in the table if it is not there already. + Args: user_id (str): canonical User ID - device_id (str): the device ID to associate with the access token + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: A tuple of: The access token for the user's session. @@ -397,6 +405,16 @@ class AuthHandler(BaseHandler): logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1f9e15c33c..a7a192e1c9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler): Args: user_id (str): - device_id (str) + device_id (str): Returns: defer.Deferred: dict[str, X]: info on the device @@ -117,6 +117,31 @@ class DeviceHandler(BaseHandler): _update_device_from_client_ips(device, ips) defer.returnValue(device) + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens(user_id, + device_id=device_id) + + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e8b791519c..92fcae674a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet): ) device_id = yield self._register_device(user_id, login_submission) access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed @@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet): ) access_token, refresh_token = ( yield auth_handler.get_login_tuple_for_user_id( - registered_user_id, device_id + registered_user_id, device_id, + login_submission.get("initial_device_display_name") ) ) result = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 8b9ab4f674..30ef8b3da9 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -70,6 +70,20 @@ class DeviceRestServlet(RestServlet): ) defer.returnValue((200, device)) + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + def register_servlets(hs, http_server): DevicesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c8c9395fc6..9f599ea8bb 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet): """ device_id = yield self._register_device(user_id, params) - access_token = yield self.auth_handler.issue_access_token( - user_id, device_id=device_id + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - refresh_token = yield self.auth_handler.issue_refresh_token( - user_id, device_id=device_id - ) defer.returnValue({ "user_id": user_id, "access_token": access_token, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 1cc6e07f2b..4689980f80 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -76,6 +76,21 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + def delete_device(self, user_id, device_id): + """Delete a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to retrieve + Returns: + defer.Deferred + """ + return self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) + @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 9a92b35361..935e82bf7a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,18 +18,31 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes - -from ._base import SQLBaseStore +from synapse.storage import background_updates from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def user_delete_access_tokens(self, user_id, except_token_ids=[], + device_id=None): def f(txn): sql = "SELECT token FROM access_tokens WHERE user_id = ?" clauses = [user_id] + if device_id is not None: + sql += " AND device_id = ?" + clauses.append(device_id) + if except_token_ids: sql += " AND id NOT IN (%s)" % ( ",".join(["?" for _ in except_token_ids]), diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/schema/delta/33/access_tokens_device_index.sql new file mode 100644 index 0000000000..61ad3fe3e8 --- /dev/null +++ b/synapse/storage/schema/delta/33/access_tokens_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('access_tokens_device_index', '{}'); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql new file mode 100644 index 0000000000..bb225dafbf --- /dev/null +++ b/synapse/storage/schema/delta/33/refreshtoken_device_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('refresh_tokens_device_index', '{}'); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 331aa13fed..214e722eb3 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse import types + from twisted.internet import defer +import synapse.api.errors import synapse.handlers.device + import synapse.storage +from synapse import types from tests import unittest, utils user1 = "@boris:aaa" @@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(DeviceTestCase, self).__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore - self.handler = None # type: device.DeviceHandler + self.handler = None # type: synapse.handlers.device.DeviceHandler self.clock = None # type: utils.MockClock @defer.inlineCallbacks @@ -123,6 +126,21 @@ class DeviceTestCase(unittest.TestCase): "last_seen_ts": 3000000, }, res) + @defer.inlineCallbacks + def test_delete_device(self): + yield self._record_users() + + # delete the device + yield self.handler.delete_device(user1, "abc") + + # check the device was deleted + with self.assertRaises(synapse.api.errors.NotFoundError): + yield self.handler.get_device(user1, "abc") + + # we'd like to check the access token was invalidated, but that's a + # bit of a PITA. + + @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 3bd7065e32..8ac56a1fb2 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase): self.registration_handler.appservice_register = Mock( return_value=user_id ) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) (code, result) = yield self.servlet.on_POST(self.request) self.assertEquals(code, 200) det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname } self.assertDictContainsSubset(det_data, result) @@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase): "password": "monkey" }, None) self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.issue_access_token = Mock(return_value=token) + self.auth_handler.get_login_tuple_for_user_id = Mock( + return_value=(token, "kermits_refresh_token") + ) self.device_handler.check_device_registered = \ Mock(return_value=device_id) @@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase): det_data = { "user_id": user_id, "access_token": token, + "refresh_token": "kermits_refresh_token", "home_server": self.hs.hostname, "device_id": device_id, } self.assertDictContainsSubset(det_data, result) self.assertIn("refresh_token", result) - self.auth_handler.issue_access_token.assert_called_once_with( - user_id, device_id=device_id) + self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, initial_device_display_name=None) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False -- cgit 1.5.1 From 0682ca04b3ac0a3e148633d020b3248dbe98f13d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:01:30 +0100 Subject: Fix CAS login Attempting to log in with CAS was giving a 500 error. --- synapse/rest/client/v1/login.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674a..d8c76a3465 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -427,6 +427,7 @@ class CasTicketServlet(ClientV1RestServlet): self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes + self.auth_handler = hs.get_auth_handler() @defer.inlineCallbacks def on_GET(self, request): -- cgit 1.5.1 From 65666fedd5f60ec65fd86d9bbdff40fa67469025 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 8 Aug 2016 17:17:25 +0100 Subject: Clean up CAS login code Remove some apparently unused code. Clean up parse_cas_response, mostly to catch the exception if the CAS response isn't valid XML. --- synapse/rest/client/v1/login.py | 158 +++++++++------------------------------- 1 file changed, 33 insertions(+), 125 deletions(-) (limited to 'synapse/rest/client/v1') diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 92fcae674a..fef7910c4f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -54,10 +54,6 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled - self.cas_server_url = hs.config.cas_server_url - self.cas_required_attributes = hs.config.cas_required_attributes - self.servername = hs.config.server_name - self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() self.device_handler = self.hs.get_device_handler() @@ -110,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.JWT_TYPE): result = yield self.do_jwt_login(login_submission) defer.returnValue(result) - # TODO Delete this after all CAS clients switch to token login instead - elif self.cas_enabled and (login_submission["type"] == - LoginRestServlet.CAS_TYPE): - uri = "%s/proxyValidate" % (self.cas_server_url,) - args = { - "ticket": login_submission["ticket"], - "service": login_submission["service"] - } - body = yield self.http_client.get_raw(uri, args) - result = yield self.do_cas_login(body) - defer.returnValue(result) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) defer.returnValue(result) @@ -191,51 +176,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - @defer.inlineCallbacks - def do_cas_login(self, cas_response_body): - user, attributes = self.parse_cas_response(cas_response_body) - - for required_attribute, required_value in self.cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in attributes: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - # Also need to check value - if required_value is not None: - actual_value = attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) - - user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.auth_handler - registered_user_id = yield auth_handler.check_user_exists(user_id) - if registered_user_id: - access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id( - registered_user_id - ) - ) - result = { - "user_id": registered_user_id, # may have changed - "access_token": access_token, - "refresh_token": refresh_token, - "home_server": self.hs.hostname, - } - - else: - user_id, access_token = ( - yield self.handlers.registration_handler.register(localpart=user) - ) - result = { - "user_id": user_id, # may have changed - "access_token": access_token, - "home_server": self.hs.hostname, - } - - defer.returnValue((200, result)) - @defer.inlineCallbacks def do_jwt_login(self, login_submission): token = login_submission.get("token", None) @@ -293,33 +233,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) - # TODO Delete this after all CAS clients switch to token login instead - def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) - def _register_device(self, user_id, login_submission): """Register a device for a user. @@ -384,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet): defer.returnValue((200, {"status": "not_authenticated"})) -# TODO Delete this after all CAS clients switch to token login instead -class CasRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas", releases=()) - - def __init__(self, hs): - super(CasRestServlet, self).__init__(hs) - self.cas_server_url = hs.config.cas_server_url - - def on_GET(self, request): - return (200, {"serverUrl": self.cas_server_url}) - - class CasRedirectServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) @@ -479,30 +380,39 @@ class CasTicketServlet(ClientV1RestServlet): return urlparse.urlunparse(url_parts) def parse_cas_response(self, cas_response_body): - root = ET.fromstring(cas_response_body) - if not root.tag.endswith("serviceResponse"): - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - if not root[0].tag.endswith("authenticationSuccess"): - raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED) - for child in root[0]: - if child.tag.endswith("user"): - user = child.text - if child.tag.endswith("attributes"): - attributes = {} - for attribute in child: - # ElementTree library expands the namespace in attribute tags - # to the full URL of the namespace. - # See (https://docs.python.org/2/library/xml.etree.elementtree.html) - # We don't care about namespace here and it will always be encased in - # curly braces, so we remove them. - if "}" in attribute.tag: - attributes[attribute.tag.split("}")[1]] = attribute.text - else: - attributes[attribute.tag] = attribute.text - if user is None or attributes is None: - raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) - - return (user, attributes) + user = None + attributes = None + try: + root = ET.fromstring(cas_response_body) + if not root.tag.endswith("serviceResponse"): + raise Exception("root of CAS response is not serviceResponse") + success = (root[0].tag.endswith("authenticationSuccess")) + for child in root[0]: + if child.tag.endswith("user"): + user = child.text + if child.tag.endswith("attributes"): + attributes = {} + for attribute in child: + # ElementTree library expands the namespace in + # attribute tags to the full URL of the namespace. + # We don't care about namespace here and it will always + # be encased in curly braces, so we remove them. + tag = attribute.tag + if "}" in tag: + tag = tag.split("}")[1] + attributes[tag] = attribute.text + if user is None: + raise Exception("CAS response does not contain user") + if attributes is None: + raise Exception("CAS response does not contain attributes") + except Exception: + logger.error("Error parsing CAS response", exc_info=1) + raise LoginError(401, "Invalid CAS response", + errcode=Codes.UNAUTHORIZED) + if not success: + raise LoginError(401, "Unsuccessful CAS response", + errcode=Codes.UNAUTHORIZED) + return user, attributes def register_servlets(hs, http_server): @@ -512,5 +422,3 @@ def register_servlets(hs, http_server): if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) - CasRestServlet(hs).register(http_server) - # TODO PasswordResetRestServlet(hs).register(http_server) -- cgit 1.5.1