summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/login.py20
-rw-r--r--synapse/rest/client/v2_alpha/account.py8
-rw-r--r--synapse/rest/client/v2_alpha/keys.py100
-rw-r--r--synapse/rest/client/v2_alpha/register.py31
4 files changed, 115 insertions, 44 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 998d4d44c6..0d5eafd0fa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -74,17 +74,23 @@ class LoginRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
-        if not login_submission["user"].startswith('@'):
-            login_submission["user"] = UserID.create(
-                login_submission["user"], self.hs.hostname).to_string()
+        if 'medium' in login_submission and 'address' in login_submission:
+            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+                login_submission['medium'], login_submission['address']
+            )
+        else:
+            user_id = login_submission['user']
+
+        if not user_id.startswith('@'):
+            user_id = UserID.create(
+                user_id, self.hs.hostname).to_string()
 
-        handler = self.handlers.login_handler
-        token = yield handler.login(
-            user=login_submission["user"],
+        token = yield self.handlers.auth_handler.login_with_password(
+            user_id=user_id,
             password=login_submission["password"])
 
         result = {
-            "user_id": login_submission["user"],  # may have changed
+            "user_id": user_id,  # may have changed
             "access_token": token,
             "home_server": self.hs.hostname,
         }
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index b082140f1f..897c54b539 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_handlers().auth_handler
-        self.login_handler = hs.get_handlers().login_handler
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
         authed, result, params = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
             [LoginType.EMAIL_IDENTITY]
-        ], body)
+        ], body, self.hs.get_ip_from_request(request))
 
         if not authed:
             defer.returnValue((401, result))
@@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet):
             raise SynapseError(400, "", Codes.MISSING_PARAM)
         new_password = params['new_password']
 
-        yield self.login_handler.set_password(
+        yield self.auth_handler.set_password(
             user_id, new_password, None
         )
 
@@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
     def __init__(self, hs):
         super(ThreepidRestServlet, self).__init__()
         self.hs = hs
-        self.login_handler = hs.get_handlers().login_handler
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
 
@@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
                 logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
                 raise SynapseError(500, "Invalid response from ID Server")
 
-        yield self.login_handler.add_threepid(
+        yield self.auth_handler.add_threepid(
             auth_user.to_string(),
             threepid['medium'],
             threepid['address'],
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 5f3a6207b5..718928eedd 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet
+from synapse.types import UserID
 from syutil.jsonutil import encode_canonical_json
 
 from ._base import client_v2_pattern
@@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
         super(KeyQueryServlet, self).__init__()
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
-        logger.debug("onPOST")
         yield self.auth.get_user_by_req(request)
         try:
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
-        for user_id, device_ids in body.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-        results = yield self.store.get_e2e_device_keys(query)
-        defer.returnValue(self.json_result(request, results))
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         auth_user, client_info = yield self.auth.get_user_by_req(request)
         auth_user_id = auth_user.to_string()
-        if not user_id:
-            user_id = auth_user_id
-        if not device_id:
-            device_id = None
-        # Returns a map of user_id->device_id->json_bytes.
-        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
-        defer.returnValue(self.json_result(request, results))
-
-    def json_result(self, request, results):
+        user_id = user_id if user_id else auth_user_id
+        device_ids = [device_id] if device_id else []
+        result = yield self.handle_request(
+            {"device_keys": {user_id: device_ids}}
+        )
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
+        for user_id, device_ids in body.get("device_keys", {}).items():
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                if not device_ids:
+                    local_query.append((user_id, None))
+                else:
+                    for device_id in device_ids:
+                        local_query.append((user_id, device_id))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = list(
+                    device_ids
+                )
+        results = yield self.store.get_e2e_device_keys(local_query)
+
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, json_bytes in device_keys.items():
                 json_result.setdefault(user_id, {})[device_id] = json.loads(
                     json_bytes
                 )
-        return (200, {"device_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.query_client_keys(
+                destination, {"device_keys": device_keys}
+            )
+            for user_id, keys in remote_result["device_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+        defer.returnValue((200, {"device_keys": json_result}))
 
 
 class OneTimeKeyServlet(RestServlet):
@@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        results = yield self.store.claim_e2e_one_time_keys(
-            [(user_id, device_id, algorithm)]
+        result = yield self.handle_request(
+            {"one_time_keys": {user_id: {device_id: algorithm}}}
         )
-        defer.returnValue(self.json_result(request, results))
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
@@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
         for user_id, device_keys in body.get("one_time_keys", {}).items():
-            for device_id, algorithm in device_keys.items():
-                query.append((user_id, device_id, algorithm))
-        results = yield self.store.claim_e2e_one_time_keys(query)
-        defer.returnValue(self.json_result(request, results))
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                for device_id, algorithm in device_keys.items():
+                    local_query.append((user_id, device_id, algorithm))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = (
+                    device_keys
+                )
+        results = yield self.store.claim_e2e_one_time_keys(local_query)
 
-    def json_result(self, request, results):
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
@@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
                     json_result.setdefault(user_id, {})[device_id] = {
                         key_id: json.loads(json_bytes)
                     }
-        return (200, {"one_time_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.claim_client_keys(
+                destination, {"one_time_keys": device_keys}
+            )
+            for user_id, keys in remote_result["one_time_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+
+        defer.returnValue((200, {"one_time_keys": json_result}))
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b5926f9ca6..1ba2f29711 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -50,11 +50,15 @@ class RegisterRestServlet(RestServlet):
         self.auth_handler = hs.get_handlers().auth_handler
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
-        self.login_handler = hs.get_handlers().login_handler
 
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
+
+        if '/register/email/requestToken' in request.path:
+            ret = yield self.onEmailTokenRequest(request)
+            defer.returnValue(ret)
+
         body = parse_json_dict_from_request(request)
 
         # we do basic sanity checks here because the auth layer will store these
@@ -143,7 +147,7 @@ class RegisterRestServlet(RestServlet):
                 if reqd not in threepid:
                     logger.info("Can't add incomplete 3pid")
                 else:
-                    yield self.login_handler.add_threepid(
+                    yield self.auth_handler.add_threepid(
                         user_id,
                         threepid['medium'],
                         threepid['address'],
@@ -209,6 +213,29 @@ class RegisterRestServlet(RestServlet):
             "home_server": self.hs.hostname,
         }
 
+    @defer.inlineCallbacks
+    def onEmailTokenRequest(self, request):
+        body = parse_json_dict_from_request(request)
+
+        required = ['id_server', 'client_secret', 'email', 'send_attempt']
+        absent = []
+        for k in required:
+            if k not in body:
+                absent.append(k)
+
+        if len(absent) > 0:
+            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+            'email', body['email']
+        )
+
+        if existingUid is not None:
+            raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+
+        ret = yield self.identity_handler.requestEmailToken(**body)
+        defer.returnValue((200, ret))
+
 
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)