summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v1/login.py5
-rw-r--r--synapse/rest/client/v2_alpha/account.py8
-rw-r--r--synapse/rest/client/v2_alpha/keys.py100
-rw-r--r--synapse/rest/client/v2_alpha/register.py3
4 files changed, 76 insertions, 40 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 998d4d44c6..694072693d 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -78,9 +78,8 @@ class LoginRestServlet(ClientV1RestServlet):
             login_submission["user"] = UserID.create(
                 login_submission["user"], self.hs.hostname).to_string()
 
-        handler = self.handlers.login_handler
-        token = yield handler.login(
-            user=login_submission["user"],
+        token = yield self.handlers.auth_handler.login_with_password(
+            user_id=login_submission["user"],
             password=login_submission["password"])
 
         result = {
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index b082140f1f..897c54b539 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_handlers().auth_handler
-        self.login_handler = hs.get_handlers().login_handler
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
         authed, result, params = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
             [LoginType.EMAIL_IDENTITY]
-        ], body)
+        ], body, self.hs.get_ip_from_request(request))
 
         if not authed:
             defer.returnValue((401, result))
@@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet):
             raise SynapseError(400, "", Codes.MISSING_PARAM)
         new_password = params['new_password']
 
-        yield self.login_handler.set_password(
+        yield self.auth_handler.set_password(
             user_id, new_password, None
         )
 
@@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
     def __init__(self, hs):
         super(ThreepidRestServlet, self).__init__()
         self.hs = hs
-        self.login_handler = hs.get_handlers().login_handler
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
 
@@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
                 logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
                 raise SynapseError(500, "Invalid response from ID Server")
 
-        yield self.login_handler.add_threepid(
+        yield self.auth_handler.add_threepid(
             auth_user.to_string(),
             threepid['medium'],
             threepid['address'],
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 5f3a6207b5..718928eedd 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet
+from synapse.types import UserID
 from syutil.jsonutil import encode_canonical_json
 
 from ._base import client_v2_pattern
@@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
         super(KeyQueryServlet, self).__init__()
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
-        logger.debug("onPOST")
         yield self.auth.get_user_by_req(request)
         try:
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
-        for user_id, device_ids in body.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-        results = yield self.store.get_e2e_device_keys(query)
-        defer.returnValue(self.json_result(request, results))
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         auth_user, client_info = yield self.auth.get_user_by_req(request)
         auth_user_id = auth_user.to_string()
-        if not user_id:
-            user_id = auth_user_id
-        if not device_id:
-            device_id = None
-        # Returns a map of user_id->device_id->json_bytes.
-        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
-        defer.returnValue(self.json_result(request, results))
-
-    def json_result(self, request, results):
+        user_id = user_id if user_id else auth_user_id
+        device_ids = [device_id] if device_id else []
+        result = yield self.handle_request(
+            {"device_keys": {user_id: device_ids}}
+        )
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
+        for user_id, device_ids in body.get("device_keys", {}).items():
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                if not device_ids:
+                    local_query.append((user_id, None))
+                else:
+                    for device_id in device_ids:
+                        local_query.append((user_id, device_id))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = list(
+                    device_ids
+                )
+        results = yield self.store.get_e2e_device_keys(local_query)
+
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, json_bytes in device_keys.items():
                 json_result.setdefault(user_id, {})[device_id] = json.loads(
                     json_bytes
                 )
-        return (200, {"device_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.query_client_keys(
+                destination, {"device_keys": device_keys}
+            )
+            for user_id, keys in remote_result["device_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+        defer.returnValue((200, {"device_keys": json_result}))
 
 
 class OneTimeKeyServlet(RestServlet):
@@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        results = yield self.store.claim_e2e_one_time_keys(
-            [(user_id, device_id, algorithm)]
+        result = yield self.handle_request(
+            {"one_time_keys": {user_id: {device_id: algorithm}}}
         )
-        defer.returnValue(self.json_result(request, results))
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
@@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
         for user_id, device_keys in body.get("one_time_keys", {}).items():
-            for device_id, algorithm in device_keys.items():
-                query.append((user_id, device_id, algorithm))
-        results = yield self.store.claim_e2e_one_time_keys(query)
-        defer.returnValue(self.json_result(request, results))
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                for device_id, algorithm in device_keys.items():
+                    local_query.append((user_id, device_id, algorithm))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = (
+                    device_keys
+                )
+        results = yield self.store.claim_e2e_one_time_keys(local_query)
 
-    def json_result(self, request, results):
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
@@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
                     json_result.setdefault(user_id, {})[device_id] = {
                         key_id: json.loads(json_bytes)
                     }
-        return (200, {"one_time_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.claim_client_keys(
+                destination, {"one_time_keys": device_keys}
+            )
+            for user_id, keys in remote_result["one_time_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+
+        defer.returnValue((200, {"one_time_keys": json_result}))
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 254c5f1ddf..1ba2f29711 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet):
         self.auth_handler = hs.get_handlers().auth_handler
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
-        self.login_handler = hs.get_handlers().login_handler
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -148,7 +147,7 @@ class RegisterRestServlet(RestServlet):
                 if reqd not in threepid:
                     logger.info("Can't add incomplete 3pid")
                 else:
-                    yield self.login_handler.add_threepid(
+                    yield self.auth_handler.add_threepid(
                         user_id,
                         threepid['medium'],
                         threepid['address'],