summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_client.py30
-rw-r--r--synapse/federation/federation_server.py43
-rw-r--r--synapse/federation/transport/client.py70
-rw-r--r--synapse/federation/transport/server.py20
-rw-r--r--synapse/handlers/message.py18
-rw-r--r--synapse/handlers/presence.py89
-rw-r--r--synapse/metrics/__init__.py28
-rw-r--r--synapse/rest/client/v2_alpha/keys.py100
-rw-r--r--synapse/storage/presence.py32
-rw-r--r--synapse/storage/state.py2
-rw-r--r--synapse/util/caches/descriptors.py2
11 files changed, 384 insertions, 50 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 58a6d6a0ed..f5e346cdbc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -134,6 +134,36 @@ class FederationClient(FederationBase):
             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
         )
 
+    @log_function
+    def query_client_keys(self, destination, content):
+        """Query device keys for a device hosted on a remote server.
+
+        Args:
+            destination (str): Domain name of the remote homeserver
+            content (dict): The query content.
+
+        Returns:
+            a Deferred which will eventually yield a JSON object from the
+            response
+        """
+        sent_queries_counter.inc("client_device_keys")
+        return self.transport_layer.query_client_keys(destination, content)
+
+    @log_function
+    def claim_client_keys(self, destination, content):
+        """Claims one-time keys for a device hosted on a remote server.
+
+        Args:
+            destination (str): Domain name of the remote homeserver
+            content (dict): The query content.
+
+        Returns:
+            a Deferred which will eventually yield a JSON object from the
+            response
+        """
+        sent_queries_counter.inc("client_one_time_keys")
+        return self.transport_layer.claim_client_keys(destination, content)
+
     @defer.inlineCallbacks
     @log_function
     def backfill(self, dest, context, limit, extremities):
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index cd79e23f4b..725c6f3fa5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
 
 from synapse.crypto.event_signing import compute_event_signature
 
+import simplejson as json
 import logging
 
 
@@ -314,6 +315,48 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     @log_function
+    def on_query_client_keys(self, origin, content):
+        query = []
+        for user_id, device_ids in content.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+
+        results = yield self.store.get_e2e_device_keys(query)
+
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+
+        defer.returnValue({"device_keys": json_result})
+
+    @defer.inlineCallbacks
+    @log_function
+    def on_claim_client_keys(self, origin, content):
+        query = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                query.append((user_id, device_id, algorithm))
+
+        results = yield self.store.claim_e2e_one_time_keys(query)
+
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, keys in device_keys.items():
+                for key_id, json_bytes in keys.items():
+                    json_result.setdefault(user_id, {})[device_id] = {
+                        key_id: json.loads(json_bytes)
+                    }
+
+        defer.returnValue({"one_time_keys": json_result})
+
+    @defer.inlineCallbacks
+    @log_function
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
         missing_events = yield self.handler.on_get_missing_events(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 610a4c3163..ced703364b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -224,6 +224,76 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def query_client_keys(self, destination, query_content):
+        """Query the device keys for a list of user ids hosted on a remote
+        server.
+
+        Request:
+            {
+              "device_keys": {
+                "<user_id>": ["<device_id>"]
+            } }
+
+        Response:
+            {
+              "device_keys": {
+                "<user_id>": {
+                  "<device_id>": {...}
+            } } }
+
+        Args:
+            destination(str): The server to query.
+            query_content(dict): The user ids to query.
+        Returns:
+            A dict containg the device keys.
+        """
+        path = PREFIX + "/user/keys/query"
+
+        content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data=query_content,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
+    def claim_client_keys(self, destination, query_content):
+        """Claim one-time keys for a list of devices hosted on a remote server.
+
+        Request:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                    "<device_id>": "<algorithm>"
+            } } }
+
+        Response:
+            {
+              "device_keys": {
+                "<user_id>": {
+                  "<device_id>": {
+                    "<algorithm>:<key_id>": "<key_base64>"
+            } } } }
+
+        Args:
+            destination(str): The server to query.
+            query_content(dict): The user ids to query.
+        Returns:
+            A dict containg the one-time keys.
+        """
+
+        path = PREFIX + "/user/keys/claim"
+
+        content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data=query_content,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
     def get_missing_events(self, destination, room_id, earliest_events,
                            latest_events, limit, min_depth):
         path = PREFIX + "/get_missing_events/%s" % (room_id,)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index bad93c6b2f..36f250e1a3 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
         defer.returnValue((200, content))
 
 
+class FederationClientKeysQueryServlet(BaseFederationServlet):
+    PATH = "/user/keys/query"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query):
+        response = yield self.handler.on_query_client_keys(origin, content)
+        defer.returnValue((200, response))
+
+
+class FederationClientKeysClaimServlet(BaseFederationServlet):
+    PATH = "/user/keys/claim"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query):
+        response = yield self.handler.on_claim_client_keys(origin, content)
+        defer.returnValue((200, response))
+
+
 class FederationQueryAuthServlet(BaseFederationServlet):
     PATH = "/query_auth/([^/]*)/([^/]*)"
 
@@ -373,4 +391,6 @@ SERVLET_CLASSES = (
     FederationQueryAuthServlet,
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
+    FederationClientKeysQueryServlet,
+    FederationClientKeysClaimServlet,
 )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 29e81085d1..f12465fa2c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -460,20 +460,14 @@ class MessageHandler(BaseHandler):
 
         @defer.inlineCallbacks
         def get_presence():
-            presence_defs = yield defer.DeferredList(
-                [
-                    presence_handler.get_state(
-                        target_user=UserID.from_string(m.user_id),
-                        auth_user=auth_user,
-                        as_event=True,
-                        check_auth=False,
-                    )
-                    for m in room_members
-                ],
-                consumeErrors=True,
+            states = yield presence_handler.get_states(
+                target_users=[UserID.from_string(m.user_id) for m in room_members],
+                auth_user=auth_user,
+                as_event=True,
+                check_auth=False,
             )
 
-            defer.returnValue([p for success, p in presence_defs if success])
+            defer.returnValue(states.values())
 
         receipts_handler = self.hs.get_handlers().receipts_handler
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 341a516da2..e91e81831e 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
+        """Get the current presence state of the given user.
+
+        Args:
+            target_user (UserID): The user whose presence we want
+            auth_user (UserID): The user requesting the presence, used for
+                checking if said user is allowed to see the persence of the
+                `target_user`
+            as_event (bool): Format the return as an event or not?
+            check_auth (bool): Perform the auth checks or not?
+
+        Returns:
+            dict: The presence state of the `target_user`, whose format depends
+            on the `as_event` argument.
+        """
         if self.hs.is_mine(target_user):
             if check_auth:
                 visible = yield self.is_presence_visible(
@@ -233,6 +247,81 @@ class PresenceHandler(BaseHandler):
             defer.returnValue(state)
 
     @defer.inlineCallbacks
+    def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
+        """A batched version of the `get_state` method that accepts a list of
+        `target_users`
+
+        Args:
+            target_users (list): The list of UserID's whose presence we want
+            auth_user (UserID): The user requesting the presence, used for
+                checking if said user is allowed to see the persence of the
+                `target_users`
+            as_event (bool): Format the return as an event or not?
+            check_auth (bool): Perform the auth checks or not?
+
+        Returns:
+            dict: A mapping from user -> presence_state
+        """
+        local_users, remote_users = partitionbool(
+            target_users,
+            lambda u: self.hs.is_mine(u)
+        )
+
+        if check_auth:
+            for user in local_users:
+                visible = yield self.is_presence_visible(
+                    observer_user=auth_user,
+                    observed_user=user
+                )
+
+                if not visible:
+                    raise SynapseError(404, "Presence information not visible")
+
+        results = {}
+        if local_users:
+            for user in local_users:
+                if user in self._user_cachemap:
+                    results[user] = self._user_cachemap[user].get_state()
+
+            local_to_user = {u.localpart: u for u in local_users}
+
+            states = yield self.store.get_presence_states(
+                [u.localpart for u in local_users if u not in results]
+            )
+
+            for local_part, state in states.items():
+                if state is None:
+                    continue
+                res = {"presence": state["state"]}
+                if "status_msg" in state and state["status_msg"]:
+                    res["status_msg"] = state["status_msg"]
+                results[local_to_user[local_part]] = res
+
+        for user in remote_users:
+            # TODO(paul): Have remote server send us permissions set
+            results[user] = self._get_or_offline_usercache(user).get_state()
+
+        for state in results.values():
+            if "last_active" in state:
+                state["last_active_ago"] = int(
+                    self.clock.time_msec() - state.pop("last_active")
+                )
+
+        if as_event:
+            for user, state in results.items():
+                content = state
+                content["user_id"] = user.to_string()
+
+                if "last_active" in content:
+                    content["last_active_ago"] = int(
+                        self._clock.time_msec() - content.pop("last_active")
+                    )
+
+                results[user] = {"type": "m.presence", "content": content}
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
     @log_function
     def set_state(self, target_user, auth_user, state):
         # return
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 0be9772991..d7bcad8a8a 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -158,18 +158,40 @@ def runUntilCurrentTimer(func):
 
     @functools.wraps(func)
     def f(*args, **kwargs):
-        pending_calls = len(reactor.getDelayedCalls())
+        now = reactor.seconds()
+        num_pending = 0
+
+        # _newTimedCalls is one long list of *all* pending calls. Below loop
+        # is based off of impl of reactor.runUntilCurrent
+        for delayed_call in reactor._newTimedCalls:
+            if delayed_call.time > now:
+                break
+
+            if delayed_call.delayed_time > 0:
+                continue
+
+            num_pending += 1
+
+        num_pending += len(reactor.threadCallQueue)
+
         start = time.time() * 1000
         ret = func(*args, **kwargs)
         end = time.time() * 1000
         tick_time.inc_by(end - start)
-        pending_calls_metric.inc_by(pending_calls)
+        pending_calls_metric.inc_by(num_pending)
         return ret
 
     return f
 
 
-if hasattr(reactor, "runUntilCurrent"):
+try:
+    # Ensure the reactor has all the attributes we expect
+    reactor.runUntilCurrent
+    reactor._newTimedCalls
+    reactor.threadCallQueue
+
     # runUntilCurrent is called when we have pending calls. It is called once
     # per iteratation after fd polling.
     reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+except AttributeError:
+    pass
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 5f3a6207b5..718928eedd 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet
+from synapse.types import UserID
 from syutil.jsonutil import encode_canonical_json
 
 from ._base import client_v2_pattern
@@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
         super(KeyQueryServlet, self).__init__()
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
-        logger.debug("onPOST")
         yield self.auth.get_user_by_req(request)
         try:
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
-        for user_id, device_ids in body.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-        results = yield self.store.get_e2e_device_keys(query)
-        defer.returnValue(self.json_result(request, results))
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         auth_user, client_info = yield self.auth.get_user_by_req(request)
         auth_user_id = auth_user.to_string()
-        if not user_id:
-            user_id = auth_user_id
-        if not device_id:
-            device_id = None
-        # Returns a map of user_id->device_id->json_bytes.
-        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
-        defer.returnValue(self.json_result(request, results))
-
-    def json_result(self, request, results):
+        user_id = user_id if user_id else auth_user_id
+        device_ids = [device_id] if device_id else []
+        result = yield self.handle_request(
+            {"device_keys": {user_id: device_ids}}
+        )
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
+        for user_id, device_ids in body.get("device_keys", {}).items():
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                if not device_ids:
+                    local_query.append((user_id, None))
+                else:
+                    for device_id in device_ids:
+                        local_query.append((user_id, device_id))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = list(
+                    device_ids
+                )
+        results = yield self.store.get_e2e_device_keys(local_query)
+
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, json_bytes in device_keys.items():
                 json_result.setdefault(user_id, {})[device_id] = json.loads(
                     json_bytes
                 )
-        return (200, {"device_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.query_client_keys(
+                destination, {"device_keys": device_keys}
+            )
+            for user_id, keys in remote_result["device_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+        defer.returnValue((200, {"device_keys": json_result}))
 
 
 class OneTimeKeyServlet(RestServlet):
@@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        results = yield self.store.claim_e2e_one_time_keys(
-            [(user_id, device_id, algorithm)]
+        result = yield self.handle_request(
+            {"one_time_keys": {user_id: {device_id: algorithm}}}
         )
-        defer.returnValue(self.json_result(request, results))
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
@@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
             body = json.loads(request.content.read())
         except:
             raise SynapseError(400, "Invalid key JSON")
-        query = []
+        result = yield self.handle_request(body)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def handle_request(self, body):
+        local_query = []
+        remote_queries = {}
         for user_id, device_keys in body.get("one_time_keys", {}).items():
-            for device_id, algorithm in device_keys.items():
-                query.append((user_id, device_id, algorithm))
-        results = yield self.store.claim_e2e_one_time_keys(query)
-        defer.returnValue(self.json_result(request, results))
+            user = UserID.from_string(user_id)
+            if self.is_mine(user):
+                for device_id, algorithm in device_keys.items():
+                    local_query.append((user_id, device_id, algorithm))
+            else:
+                remote_queries.setdefault(user.domain, {})[user_id] = (
+                    device_keys
+                )
+        results = yield self.store.claim_e2e_one_time_keys(local_query)
 
-    def json_result(self, request, results):
         json_result = {}
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
@@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
                     json_result.setdefault(user_id, {})[device_id] = {
                         key_id: json.loads(json_bytes)
                     }
-        return (200, {"one_time_keys": json_result})
+
+        for destination, device_keys in remote_queries.items():
+            remote_result = yield self.federation.claim_client_keys(
+                destination, {"one_time_keys": device_keys}
+            )
+            for user_id, keys in remote_result["one_time_keys"].items():
+                if user_id in device_keys:
+                    json_result[user_id] = keys
+
+        defer.returnValue((200, {"one_time_keys": json_result}))
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 4f91a2b87c..34ca3b9a54 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -14,19 +14,22 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 
 from twisted.internet import defer
 
 
 class PresenceStore(SQLBaseStore):
     def create_presence(self, user_localpart):
-        return self._simple_insert(
+        res = self._simple_insert(
             table="presence",
             values={"user_id": user_localpart},
             desc="create_presence",
         )
 
+        self.get_presence_state.invalidate((user_localpart,))
+        return res
+
     def has_presence_state(self, user_localpart):
         return self._simple_select_one(
             table="presence",
@@ -36,6 +39,7 @@ class PresenceStore(SQLBaseStore):
             desc="has_presence_state",
         )
 
+    @cached(max_entries=2000)
     def get_presence_state(self, user_localpart):
         return self._simple_select_one(
             table="presence",
@@ -44,8 +48,27 @@ class PresenceStore(SQLBaseStore):
             desc="get_presence_state",
         )
 
+    @cachedList(get_presence_state.cache, list_name="user_localparts")
+    def get_presence_states(self, user_localparts):
+        def f(txn):
+            results = {}
+            for user_localpart in user_localparts:
+                res = self._simple_select_one_txn(
+                    txn,
+                    table="presence",
+                    keyvalues={"user_id": user_localpart},
+                    retcols=["state", "status_msg", "mtime"],
+                    allow_none=True,
+                )
+                if res:
+                    results[user_localpart] = res
+
+            return results
+
+        return self.runInteraction("get_presence_states", f)
+
     def set_presence_state(self, user_localpart, new_state):
-        return self._simple_update_one(
+        res = self._simple_update_one(
             table="presence",
             keyvalues={"user_id": user_localpart},
             updatevalues={"state": new_state["state"],
@@ -54,6 +77,9 @@ class PresenceStore(SQLBaseStore):
             desc="set_presence_state",
         )
 
+        self.get_presence_state.invalidate((user_localpart,))
+        return res
+
     def allow_presence_visible(self, observed_localpart, observer_userid):
         return self._simple_insert(
             table="presence_allow_inbound",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ecb62e6dfd..ab3ad5a076 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -240,7 +240,7 @@ class StateStore(SQLBaseStore):
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
-    @cached(num_args=2, lru=True, max_entries=100000)
+    @cached(num_args=2, lru=True, max_entries=10000)
     def _get_state_group_for_event(self, room_id, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 83bfec2f02..362944bc51 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -293,7 +293,7 @@ class CacheListDescriptor(object):
                 # we can insert the new deferred into the cache.
                 for arg in missing:
                     observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r[arg], arg)
+                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)