summary refs log tree commit diff
path: root/synapse/rest/client/v1
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1')
-rw-r--r--synapse/rest/client/v1/login.py67
-rw-r--r--synapse/rest/client/v1/presence.py20
-rw-r--r--synapse/rest/client/v1/push_rule.py6
-rw-r--r--synapse/rest/client/v1/pusher.py101
-rw-r--r--synapse/rest/client/v1/register.py71
-rw-r--r--synapse/rest/client/v1/room.py56
6 files changed, 295 insertions, 26 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index fe593d07ce..8df9d10efa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -43,20 +43,27 @@ class LoginRestServlet(ClientV1RestServlet):
     SAML2_TYPE = "m.login.saml2"
     CAS_TYPE = "m.login.cas"
     TOKEN_TYPE = "m.login.token"
+    JWT_TYPE = "m.login.jwt"
 
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__(hs)
         self.idp_redirect_url = hs.config.saml2_idp_redirect_url
         self.password_enabled = hs.config.password_enabled
         self.saml2_enabled = hs.config.saml2_enabled
+        self.jwt_enabled = hs.config.jwt_enabled
+        self.jwt_secret = hs.config.jwt_secret
+        self.jwt_algorithm = hs.config.jwt_algorithm
         self.cas_enabled = hs.config.cas_enabled
         self.cas_server_url = hs.config.cas_server_url
         self.cas_required_attributes = hs.config.cas_required_attributes
         self.servername = hs.config.server_name
         self.http_client = hs.get_simple_http_client()
+        self.auth_handler = self.hs.get_auth_handler()
 
     def on_GET(self, request):
         flows = []
+        if self.jwt_enabled:
+            flows.append({"type": LoginRestServlet.JWT_TYPE})
         if self.saml2_enabled:
             flows.append({"type": LoginRestServlet.SAML2_TYPE})
         if self.cas_enabled:
@@ -98,6 +105,10 @@ class LoginRestServlet(ClientV1RestServlet):
                     "uri": "%s%s" % (self.idp_redirect_url, relay_state)
                 }
                 defer.returnValue((200, result))
+            elif self.jwt_enabled and (login_submission["type"] ==
+                                       LoginRestServlet.JWT_TYPE):
+                result = yield self.do_jwt_login(login_submission)
+                defer.returnValue(result)
             # TODO Delete this after all CAS clients switch to token login instead
             elif self.cas_enabled and (login_submission["type"] ==
                                        LoginRestServlet.CAS_TYPE):
@@ -133,7 +144,7 @@ class LoginRestServlet(ClientV1RestServlet):
                 user_id, self.hs.hostname
             ).to_string()
 
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id, access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=user_id,
             password=login_submission["password"])
@@ -150,7 +161,7 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_token_login(self, login_submission):
         token = login_submission['token']
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
@@ -184,7 +195,7 @@ class LoginRestServlet(ClientV1RestServlet):
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if user_exists:
             user_id, access_token, refresh_token = (
@@ -209,6 +220,54 @@ class LoginRestServlet(ClientV1RestServlet):
 
         defer.returnValue((200, result))
 
+    @defer.inlineCallbacks
+    def do_jwt_login(self, login_submission):
+        token = login_submission.get("token", None)
+        if token is None:
+            raise LoginError(
+                401, "Token field for JWT is missing",
+                errcode=Codes.UNAUTHORIZED
+            )
+
+        import jwt
+        from jwt.exceptions import InvalidTokenError
+
+        try:
+            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+        except jwt.ExpiredSignatureError:
+            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
+        except InvalidTokenError:
+            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+
+        user = payload.get("sub", None)
+        if user is None:
+            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+
+        user_id = UserID.create(user, self.hs.hostname).to_string()
+        auth_handler = self.auth_handler
+        user_exists = yield auth_handler.does_user_exist(user_id)
+        if user_exists:
+            user_id, access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(user_id)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "home_server": self.hs.hostname,
+            }
+        else:
+            user_id, access_token = (
+                yield self.handlers.registration_handler.register(localpart=user)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+            }
+
+        defer.returnValue((200, result))
+
     # TODO Delete this after all CAS clients switch to token login instead
     def parse_cas_response(self, cas_response_body):
         root = ET.fromstring(cas_response_body)
@@ -354,7 +413,7 @@ class CasTicketServlet(ClientV1RestServlet):
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID.create(user, self.hs.hostname).to_string()
-        auth_handler = self.handlers.auth_handler
+        auth_handler = self.auth_handler
         user_exists = yield auth_handler.does_user_exist(user_id)
         if not user_exists:
             user_id, _ = (
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 27d9ed586b..eafdce865e 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -30,20 +30,24 @@ logger = logging.getLogger(__name__)
 class PresenceStatusRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
 
+    def __init__(self, hs):
+        super(PresenceStatusRestServlet, self).__init__(hs)
+        self.presence_handler = hs.get_presence_handler()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         requester = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
         if requester.user != user:
-            allowed = yield self.handlers.presence_handler.is_visible(
+            allowed = yield self.presence_handler.is_visible(
                 observed_user=user, observer_user=requester.user,
             )
 
             if not allowed:
                 raise AuthError(403, "You are not allowed to see their presence.")
 
-        state = yield self.handlers.presence_handler.get_state(target_user=user)
+        state = yield self.presence_handler.get_state(target_user=user)
 
         defer.returnValue((200, state))
 
@@ -74,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         except:
             raise SynapseError(400, "Unable to parse state")
 
-        yield self.handlers.presence_handler.set_state(user, state)
+        yield self.presence_handler.set_state(user, state)
 
         defer.returnValue((200, {}))
 
@@ -85,6 +89,10 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
 class PresenceListRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
 
+    def __init__(self, hs):
+        super(PresenceListRestServlet, self).__init__(hs)
+        self.presence_handler = hs.get_presence_handler()
+
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         requester = yield self.auth.get_user_by_req(request)
@@ -96,7 +104,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
         if requester.user != user:
             raise SynapseError(400, "Cannot get another user's presence list")
 
-        presence = yield self.handlers.presence_handler.get_presence_list(
+        presence = yield self.presence_handler.get_presence_list(
             observer_user=user, accepted=True
         )
 
@@ -123,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
                 if len(u) == 0:
                     continue
                 invited_user = UserID.from_string(u)
-                yield self.handlers.presence_handler.send_presence_invite(
+                yield self.presence_handler.send_presence_invite(
                     observer_user=user, observed_user=invited_user
                 )
 
@@ -134,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
                 if len(u) == 0:
                     continue
                 dropped_user = UserID.from_string(u)
-                yield self.handlers.presence_handler.drop(
+                yield self.presence_handler.drop(
                     observer_user=user, observed_user=dropped_user
                 )
 
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 02d837ee6a..6bb4821ec6 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rawrules = yield self.store.get_push_rules_for_user(user_id)
+        rules = yield self.store.get_push_rules_for_user(user_id)
 
-        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
-
-        rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
+        rules = format_push_rules_for_user(requester.user, rules)
 
         path = request.postpath[1:]
 
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 9881f068c3..9a2ed6ed88 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -17,7 +17,11 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, Codes
 from synapse.push import PusherConfigException
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+    parse_json_object_from_request, parse_string, RestServlet
+)
+from synapse.http.server import finish_request
+from synapse.api.errors import StoreError
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -26,11 +30,48 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class PusherRestServlet(ClientV1RestServlet):
+class PushersRestServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/pushers$")
+
+    def __init__(self, hs):
+        super(PushersRestServlet, self).__init__(hs)
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+        user = requester.user
+
+        pushers = yield self.hs.get_datastore().get_pushers_by_user_id(
+            user.to_string()
+        )
+
+        allowed_keys = [
+            "app_display_name",
+            "app_id",
+            "data",
+            "device_display_name",
+            "kind",
+            "lang",
+            "profile_tag",
+            "pushkey",
+        ]
+
+        for p in pushers:
+            for k, v in p.items():
+                if k not in allowed_keys:
+                    del p[k]
+
+        defer.returnValue((200, {"pushers": pushers}))
+
+    def on_OPTIONS(self, _):
+        return 200, {}
+
+
+class PushersSetRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/pushers/set$")
 
     def __init__(self, hs):
-        super(PusherRestServlet, self).__init__(hs)
+        super(PushersSetRestServlet, self).__init__(hs)
         self.notifier = hs.get_notifier()
 
     @defer.inlineCallbacks
@@ -99,5 +140,57 @@ class PusherRestServlet(ClientV1RestServlet):
         return 200, {}
 
 
+class PushersRemoveRestServlet(RestServlet):
+    """
+    To allow pusher to be delete by clicking a link (ie. GET request)
+    """
+    PATTERNS = client_path_patterns("/pushers/remove$")
+    SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
+
+    def __init__(self, hs):
+        super(RestServlet, self).__init__()
+        self.hs = hs
+        self.notifier = hs.get_notifier()
+        self.auth = hs.get_v1auth()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+        user = requester.user
+
+        app_id = parse_string(request, "app_id", required=True)
+        pushkey = parse_string(request, "pushkey", required=True)
+
+        pusher_pool = self.hs.get_pusherpool()
+
+        try:
+            yield pusher_pool.remove_pusher(
+                app_id=app_id,
+                pushkey=pushkey,
+                user_id=user.to_string(),
+            )
+        except StoreError as se:
+            if se.code != 404:
+                # This is fine: they're already unsubscribed
+                raise
+
+        self.notifier.on_new_replication_data()
+
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Server", self.hs.version_string)
+        request.setHeader(b"Content-Length", b"%d" % (
+            len(PushersRemoveRestServlet.SUCCESS_HTML),
+        ))
+        request.write(PushersRemoveRestServlet.SUCCESS_HTML)
+        finish_request(request)
+        defer.returnValue(None)
+
+    def on_OPTIONS(self, _):
+        return 200, {}
+
+
 def register_servlets(hs, http_server):
-    PusherRestServlet(hs).register(http_server)
+    PushersRestServlet(hs).register(http_server)
+    PushersSetRestServlet(hs).register(http_server)
+    PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index c6a2ef2ccc..e3f4fbb0bb 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
             )
 
 
+class CreateUserRestServlet(ClientV1RestServlet):
+    """Handles user creation via a server-to-server interface
+    """
+
+    PATTERNS = client_path_patterns("/createUser$", releases=())
+
+    def __init__(self, hs):
+        super(CreateUserRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        user_json = parse_json_object_from_request(request)
+
+        if "access_token" not in request.args:
+            raise SynapseError(400, "Expected application service token.")
+
+        app_service = yield self.store.get_app_service_by_token(
+            request.args["access_token"][0]
+        )
+        if not app_service:
+            raise SynapseError(403, "Invalid application service token.")
+
+        logger.debug("creating user: %s", user_json)
+
+        response = yield self._do_create(user_json)
+
+        defer.returnValue((200, response))
+
+    def on_OPTIONS(self, request):
+        return 403, {}
+
+    @defer.inlineCallbacks
+    def _do_create(self, user_json):
+        yield run_on_reactor()
+
+        if "localpart" not in user_json:
+            raise SynapseError(400, "Expected 'localpart' key.")
+
+        if "displayname" not in user_json:
+            raise SynapseError(400, "Expected 'displayname' key.")
+
+        if "duration_seconds" not in user_json:
+            raise SynapseError(400, "Expected 'duration_seconds' key.")
+
+        localpart = user_json["localpart"].encode("utf-8")
+        displayname = user_json["displayname"].encode("utf-8")
+        duration_seconds = 0
+        try:
+            duration_seconds = int(user_json["duration_seconds"])
+        except ValueError:
+            raise SynapseError(400, "Failed to parse 'duration_seconds'")
+        if duration_seconds > self.direct_user_creation_max_duration:
+            duration_seconds = self.direct_user_creation_max_duration
+
+        handler = self.handlers.registration_handler
+        user_id, token = yield handler.get_or_create_user(
+            localpart=localpart,
+            displayname=displayname,
+            duration_seconds=duration_seconds
+        )
+
+        defer.returnValue({
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname,
+        })
+
+
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)
+    CreateUserRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index a1fa7daf79..db52a1fc39 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -232,7 +232,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
 
         if RoomID.is_valid(room_identifier):
             room_id = room_identifier
-            remote_room_hosts = None
+            try:
+                remote_room_hosts = request.args["server_name"]
+            except:
+                remote_room_hosts = None
         elif RoomAlias.is_valid(room_identifier):
             handler = self.handlers.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
@@ -276,8 +279,9 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        handler = self.handlers.room_list_handler
-        data = yield handler.get_public_room_list()
+        handler = self.hs.get_room_list_handler()
+        data = yield handler.get_aggregated_public_room_list()
+
         defer.returnValue((200, data))
 
 
@@ -405,6 +409,42 @@ class RoomEventContext(ClientV1RestServlet):
         defer.returnValue((200, results))
 
 
+class RoomForgetRestServlet(ClientV1RestServlet):
+    def register(self, http_server):
+        PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
+        register_txn_path(self, PATTERNS, http_server)
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id, txn_id=None):
+        requester = yield self.auth.get_user_by_req(
+            request,
+            allow_guest=False,
+        )
+
+        yield self.handlers.room_member_handler.forget(
+            user=requester.user,
+            room_id=room_id,
+        )
+
+        defer.returnValue((200, {}))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, room_id, txn_id):
+        try:
+            defer.returnValue(
+                self.txns.get_client_transaction(request, txn_id)
+            )
+        except KeyError:
+            pass
+
+        response = yield self.on_POST(
+            request, room_id, txn_id
+        )
+
+        self.txns.store_client_transaction(request, txn_id, response)
+        defer.returnValue(response)
+
+
 # TODO: Needs unit testing
 class RoomMembershipRestServlet(ClientV1RestServlet):
 
@@ -534,7 +574,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(RoomTypingRestServlet, self).__init__(hs)
-        self.presence_handler = hs.get_handlers().presence_handler
+        self.presence_handler = hs.get_presence_handler()
+        self.typing_handler = hs.get_typing_handler()
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, user_id):
@@ -545,19 +586,17 @@ class RoomTypingRestServlet(ClientV1RestServlet):
 
         content = parse_json_object_from_request(request)
 
-        typing_handler = self.handlers.typing_notification_handler
-
         yield self.presence_handler.bump_presence_active_time(requester.user)
 
         if content["typing"]:
-            yield typing_handler.started_typing(
+            yield self.typing_handler.started_typing(
                 target_user=target_user,
                 auth_user=requester.user,
                 room_id=room_id,
                 timeout=content.get("timeout", 30000),
             )
         else:
-            yield typing_handler.stopped_typing(
+            yield self.typing_handler.stopped_typing(
                 target_user=target_user,
                 auth_user=requester.user,
                 room_id=room_id,
@@ -624,6 +663,7 @@ def register_servlets(hs, http_server):
     RoomMemberListRestServlet(hs).register(http_server)
     RoomMessageListRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
+    RoomForgetRestServlet(hs).register(http_server)
     RoomMembershipRestServlet(hs).register(http_server)
     RoomSendEventRestServlet(hs).register(http_server)
     PublicRoomListRestServlet(hs).register(http_server)