diff options
Diffstat (limited to 'synapse/rest/client/v1')
-rw-r--r-- | synapse/rest/client/v1/login.py | 67 | ||||
-rw-r--r-- | synapse/rest/client/v1/presence.py | 20 | ||||
-rw-r--r-- | synapse/rest/client/v1/push_rule.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/v1/pusher.py | 101 | ||||
-rw-r--r-- | synapse/rest/client/v1/register.py | 71 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 56 |
6 files changed, 295 insertions, 26 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index fe593d07ce..8df9d10efa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -43,20 +43,27 @@ class LoginRestServlet(ClientV1RestServlet): SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" TOKEN_TYPE = "m.login.token" + JWT_TYPE = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__(hs) self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.password_enabled = hs.config.password_enabled self.saml2_enabled = hs.config.saml2_enabled + self.jwt_enabled = hs.config.jwt_enabled + self.jwt_secret = hs.config.jwt_secret + self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled self.cas_server_url = hs.config.cas_server_url self.cas_required_attributes = hs.config.cas_required_attributes self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() + self.auth_handler = self.hs.get_auth_handler() def on_GET(self, request): flows = [] + if self.jwt_enabled: + flows.append({"type": LoginRestServlet.JWT_TYPE}) if self.saml2_enabled: flows.append({"type": LoginRestServlet.SAML2_TYPE}) if self.cas_enabled: @@ -98,6 +105,10 @@ class LoginRestServlet(ClientV1RestServlet): "uri": "%s%s" % (self.idp_redirect_url, relay_state) } defer.returnValue((200, result)) + elif self.jwt_enabled and (login_submission["type"] == + LoginRestServlet.JWT_TYPE): + result = yield self.do_jwt_login(login_submission) + defer.returnValue(result) # TODO Delete this after all CAS clients switch to token login instead elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): @@ -133,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet): user_id, self.hs.hostname ).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id, access_token, refresh_token = yield auth_handler.login_with_password( user_id=user_id, password=login_submission["password"]) @@ -150,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): token = login_submission['token'] - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @@ -184,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if user_exists: user_id, access_token, refresh_token = ( @@ -209,6 +220,54 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue((200, result)) + @defer.inlineCallbacks + def do_jwt_login(self, login_submission): + token = login_submission.get("token", None) + if token is None: + raise LoginError( + 401, "Token field for JWT is missing", + errcode=Codes.UNAUTHORIZED + ) + + import jwt + from jwt.exceptions import InvalidTokenError + + try: + payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + except jwt.ExpiredSignatureError: + raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) + except InvalidTokenError: + raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + + user = payload.get("sub", None) + if user is None: + raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + + user_id = UserID.create(user, self.hs.hostname).to_string() + auth_handler = self.auth_handler + user_exists = yield auth_handler.does_user_exist(user_id) + if user_exists: + user_id, access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id(user_id) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "refresh_token": refresh_token, + "home_server": self.hs.hostname, + } + else: + user_id, access_token = ( + yield self.handlers.registration_handler.register(localpart=user) + ) + result = { + "user_id": user_id, # may have changed + "access_token": access_token, + "home_server": self.hs.hostname, + } + + defer.returnValue((200, result)) + # TODO Delete this after all CAS clients switch to token login instead def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) @@ -354,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) user_id = UserID.create(user, self.hs.hostname).to_string() - auth_handler = self.handlers.auth_handler + auth_handler = self.auth_handler user_exists = yield auth_handler.does_user_exist(user_id) if not user_exists: user_id, _ = ( diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 27d9ed586b..eafdce865e 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -30,20 +30,24 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") + def __init__(self, hs): + super(PresenceStatusRestServlet, self).__init__(hs) + self.presence_handler = hs.get_presence_handler() + @defer.inlineCallbacks def on_GET(self, request, user_id): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.handlers.presence_handler.is_visible( + allowed = yield self.presence_handler.is_visible( observed_user=user, observer_user=requester.user, ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.handlers.presence_handler.get_state(target_user=user) + state = yield self.presence_handler.get_state(target_user=user) defer.returnValue((200, state)) @@ -74,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except: raise SynapseError(400, "Unable to parse state") - yield self.handlers.presence_handler.set_state(user, state) + yield self.presence_handler.set_state(user, state) defer.returnValue((200, {})) @@ -85,6 +89,10 @@ class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)") + def __init__(self, hs): + super(PresenceListRestServlet, self).__init__(hs) + self.presence_handler = hs.get_presence_handler() + @defer.inlineCallbacks def on_GET(self, request, user_id): requester = yield self.auth.get_user_by_req(request) @@ -96,7 +104,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if requester.user != user: raise SynapseError(400, "Cannot get another user's presence list") - presence = yield self.handlers.presence_handler.get_presence_list( + presence = yield self.presence_handler.get_presence_list( observer_user=user, accepted=True ) @@ -123,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue invited_user = UserID.from_string(u) - yield self.handlers.presence_handler.send_presence_invite( + yield self.presence_handler.send_presence_invite( observer_user=user, observed_user=invited_user ) @@ -134,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue dropped_user = UserID.from_string(u) - yield self.handlers.presence_handler.drop( + yield self.presence_handler.drop( observer_user=user, observed_user=dropped_user ) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 02d837ee6a..6bb4821ec6 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet): # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rawrules = yield self.store.get_push_rules_for_user(user_id) + rules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - - rules = format_push_rules_for_user(requester.user, rawrules, enabled_map) + rules = format_push_rules_for_user(requester.user, rules) path = request.postpath[1:] diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 9881f068c3..9a2ed6ed88 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,7 +17,11 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import ( + parse_json_object_from_request, parse_string, RestServlet +) +from synapse.http.server import finish_request +from synapse.api.errors import StoreError from .base import ClientV1RestServlet, client_path_patterns @@ -26,11 +30,48 @@ import logging logger = logging.getLogger(__name__) -class PusherRestServlet(ClientV1RestServlet): +class PushersRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/pushers$") + + def __init__(self, hs): + super(PushersRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + user = requester.user + + pushers = yield self.hs.get_datastore().get_pushers_by_user_id( + user.to_string() + ) + + allowed_keys = [ + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", + ] + + for p in pushers: + for k, v in p.items(): + if k not in allowed_keys: + del p[k] + + defer.returnValue((200, {"pushers": pushers})) + + def on_OPTIONS(self, _): + return 200, {} + + +class PushersSetRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/pushers/set$") def __init__(self, hs): - super(PusherRestServlet, self).__init__(hs) + super(PushersSetRestServlet, self).__init__(hs) self.notifier = hs.get_notifier() @defer.inlineCallbacks @@ -99,5 +140,57 @@ class PusherRestServlet(ClientV1RestServlet): return 200, {} +class PushersRemoveRestServlet(RestServlet): + """ + To allow pusher to be delete by clicking a link (ie. GET request) + """ + PATTERNS = client_path_patterns("/pushers/remove$") + SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>" + + def __init__(self, hs): + super(RestServlet, self).__init__() + self.hs = hs + self.notifier = hs.get_notifier() + self.auth = hs.get_v1auth() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + user = requester.user + + app_id = parse_string(request, "app_id", required=True) + pushkey = parse_string(request, "pushkey", required=True) + + pusher_pool = self.hs.get_pusherpool() + + try: + yield pusher_pool.remove_pusher( + app_id=app_id, + pushkey=pushkey, + user_id=user.to_string(), + ) + except StoreError as se: + if se.code != 404: + # This is fine: they're already unsubscribed + raise + + self.notifier.on_new_replication_data() + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Server", self.hs.version_string) + request.setHeader(b"Content-Length", b"%d" % ( + len(PushersRemoveRestServlet.SUCCESS_HTML), + )) + request.write(PushersRemoveRestServlet.SUCCESS_HTML) + finish_request(request) + defer.returnValue(None) + + def on_OPTIONS(self, _): + return 200, {} + + def register_servlets(hs, http_server): - PusherRestServlet(hs).register(http_server) + PushersRestServlet(hs).register(http_server) + PushersSetRestServlet(hs).register(http_server) + PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index c6a2ef2ccc..e3f4fbb0bb 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet): ) +class CreateUserRestServlet(ClientV1RestServlet): + """Handles user creation via a server-to-server interface + """ + + PATTERNS = client_path_patterns("/createUser$", releases=()) + + def __init__(self, hs): + super(CreateUserRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.direct_user_creation_max_duration = hs.config.user_creation_max_duration + + @defer.inlineCallbacks + def on_POST(self, request): + user_json = parse_json_object_from_request(request) + + if "access_token" not in request.args: + raise SynapseError(400, "Expected application service token.") + + app_service = yield self.store.get_app_service_by_token( + request.args["access_token"][0] + ) + if not app_service: + raise SynapseError(403, "Invalid application service token.") + + logger.debug("creating user: %s", user_json) + + response = yield self._do_create(user_json) + + defer.returnValue((200, response)) + + def on_OPTIONS(self, request): + return 403, {} + + @defer.inlineCallbacks + def _do_create(self, user_json): + yield run_on_reactor() + + if "localpart" not in user_json: + raise SynapseError(400, "Expected 'localpart' key.") + + if "displayname" not in user_json: + raise SynapseError(400, "Expected 'displayname' key.") + + if "duration_seconds" not in user_json: + raise SynapseError(400, "Expected 'duration_seconds' key.") + + localpart = user_json["localpart"].encode("utf-8") + displayname = user_json["displayname"].encode("utf-8") + duration_seconds = 0 + try: + duration_seconds = int(user_json["duration_seconds"]) + except ValueError: + raise SynapseError(400, "Failed to parse 'duration_seconds'") + if duration_seconds > self.direct_user_creation_max_duration: + duration_seconds = self.direct_user_creation_max_duration + + handler = self.handlers.registration_handler + user_id, token = yield handler.get_or_create_user( + localpart=localpart, + displayname=displayname, + duration_seconds=duration_seconds + ) + + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + + def register_servlets(hs, http_server): RegisterRestServlet(hs).register(http_server) + CreateUserRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a1fa7daf79..db52a1fc39 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -232,7 +232,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier - remote_room_hosts = None + try: + remote_room_hosts = request.args["server_name"] + except: + remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.handlers.room_member_handler room_alias = RoomAlias.from_string(room_identifier) @@ -276,8 +279,9 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - handler = self.handlers.room_list_handler - data = yield handler.get_public_room_list() + handler = self.hs.get_room_list_handler() + data = yield handler.get_aggregated_public_room_list() + defer.returnValue((200, data)) @@ -405,6 +409,42 @@ class RoomEventContext(ClientV1RestServlet): defer.returnValue((200, results)) +class RoomForgetRestServlet(ClientV1RestServlet): + def register(self, http_server): + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") + register_txn_path(self, PATTERNS, http_server) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, txn_id=None): + requester = yield self.auth.get_user_by_req( + request, + allow_guest=False, + ) + + yield self.handlers.room_member_handler.forget( + user=requester.user, + room_id=room_id, + ) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, txn_id): + try: + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) + except KeyError: + pass + + response = yield self.on_POST( + request, room_id, txn_id + ) + + self.txns.store_client_transaction(request, txn_id, response) + defer.returnValue(response) + + # TODO: Needs unit testing class RoomMembershipRestServlet(ClientV1RestServlet): @@ -534,7 +574,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomTypingRestServlet, self).__init__(hs) - self.presence_handler = hs.get_handlers().presence_handler + self.presence_handler = hs.get_presence_handler() + self.typing_handler = hs.get_typing_handler() @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): @@ -545,19 +586,17 @@ class RoomTypingRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) - typing_handler = self.handlers.typing_notification_handler - yield self.presence_handler.bump_presence_active_time(requester.user) if content["typing"]: - yield typing_handler.started_typing( + yield self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=content.get("timeout", 30000), ) else: - yield typing_handler.stopped_typing( + yield self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, @@ -624,6 +663,7 @@ def register_servlets(hs, http_server): RoomMemberListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) |