Diffstat (limited to 'synapse/handlers')
 synapse/handlers/e2e_keys.py | 184
 synapse/handlers/presence.py |  24
 synapse/handlers/receipts.py |   1
 synapse/handlers/room.py     |   6
 synapse/handlers/typing.py   |   1
 5 files changed, 177 insertions(+), 39 deletions(-)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..fd11935b40 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
-import json
+import ujson as json
 import logging
 
+from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -29,8 +31,9 @@ class E2eKeysHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
+        self.device_handler = hs.get_device_handler()
         self.is_mine_id = hs.is_mine_id
-        self.server_name = hs.hostname
+        self.clock = hs.get_clock()
 
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
@@ -40,7 +43,7 @@ class E2eKeysHandler(object):
         )
 
     @defer.inlineCallbacks
-    def query_devices(self, query_body):
+    def query_devices(self, query_body, timeout):
         """ Handle a device key query from a client
 
         {
@@ -63,27 +66,60 @@ class E2eKeysHandler(object):
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
-        queries_by_domain = collections.defaultdict(dict)
+        local_query = {}
+        remote_queries = {}
+
         for user_id, device_ids in device_keys_query.items():
-            user = synapse.types.UserID.from_string(user_id)
-            queries_by_domain[user.domain][user_id] = device_ids
+            if self.is_mine_id(user_id):
+                local_query[user_id] = device_ids
+            else:
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_ids
 
         # do the queries
-        # TODO: do these in parallel
+        failures = {}
         results = {}
-        for destination, destination_query in queries_by_domain.items():
-            if destination == self.server_name:
-                res = yield self.query_local_devices(destination_query)
-            else:
-                res = yield self.federation.query_client_keys(
-                    destination, {"device_keys": destination_query}
-                )
-                res = res["device_keys"]
-            for user_id, keys in res.items():
-                if user_id in destination_query:
+        if local_query:
+            local_result = yield self.query_local_devices(local_query)
+            for user_id, keys in local_result.items():
+                if user_id in local_query:
                     results[user_id] = keys
 
-        defer.returnValue((200, {"device_keys": results}))
+        @defer.inlineCallbacks
+        def do_remote_query(destination):
+            destination_query = remote_queries[destination]
+            try:
+                limiter = yield get_retry_limiter(
+                    destination, self.clock, self.store
+                )
+                with limiter:
+                    remote_result = yield self.federation.query_client_keys(
+                        destination,
+                        {"device_keys": destination_query},
+                        timeout=timeout
+                    )
+
+                for user_id, keys in remote_result["device_keys"].items():
+                    if user_id in destination_query:
+                        results[user_id] = keys
+
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+            except NotRetryingDestination as e:
+                failures[destination] = {
+                    "status": 503, "message": "Not ready for retry",
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(do_remote_query)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue({
+            "device_keys": results, "failures": failures,
+        })
 
     @defer.inlineCallbacks
     def query_local_devices(self, query):
@@ -104,7 +140,7 @@ class E2eKeysHandler(object):
             if not self.is_mine_id(user_id):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
-                raise errors.SynapseError(400, "Not a user here")
+                raise SynapseError(400, "Not a user here")
 
             if not device_ids:
                 local_query.append((user_id, None))
@@ -137,3 +173,107 @@ class E2eKeysHandler(object):
         device_keys_query = query_body.get("device_keys", {})
         res = yield self.query_local_devices(device_keys_query)
         defer.returnValue({"device_keys": res})
+
+    @defer.inlineCallbacks
+    def claim_one_time_keys(self, query, timeout):
+        local_query = []
+        remote_queries = {}
+
+        for user_id, device_keys in query.get("one_time_keys", {}).items():
+            if self.is_mine_id(user_id):
+                for device_id, algorithm in device_keys.items():
+                    local_query.append((user_id, device_id, algorithm))
+            else:
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_keys
+
+        results = yield self.store.claim_e2e_one_time_keys(local_query)
+
+        json_result = {}
+        failures = {}
+        for user_id, device_keys in results.items():
+            for device_id, keys in device_keys.items():
+                for key_id, json_bytes in keys.items():
+                    json_result.setdefault(user_id, {})[device_id] = {
+                        key_id: json.loads(json_bytes)
+                    }
+
+        @defer.inlineCallbacks
+        def claim_client_keys(destination):
+            device_keys = remote_queries[destination]
+            try:
+                limiter = yield get_retry_limiter(
+                    destination, self.clock, self.store
+                )
+                with limiter:
+                    remote_result = yield self.federation.claim_client_keys(
+                        destination,
+                        {"one_time_keys": device_keys},
+                        timeout=timeout
+                    )
+                    for user_id, keys in remote_result["one_time_keys"].items():
+                        if user_id in device_keys:
+                            json_result[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+            except NotRetryingDestination as e:
+                failures[destination] = {
+                    "status": 503, "message": "Not ready for retry",
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(claim_client_keys)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue({
+            "one_time_keys": json_result,
+            "failures": failures
+        })
+
+    @defer.inlineCallbacks
+    def upload_keys_for_user(self, user_id, device_id, keys):
+        time_now = self.clock.time_msec()
+
+        # TODO: Validate the JSON to make sure it has the right keys.
+        device_keys = keys.get("device_keys", None)
+        if device_keys:
+            logger.info(
+                "Updating device_keys for device %r for user %s at %d",
+                device_id, user_id, time_now
+            )
+            # TODO: Sign the JSON with the server key
+            yield self.store.set_e2e_device_keys(
+                user_id, device_id, time_now,
+                encode_canonical_json(device_keys)
+            )
+
+        one_time_keys = keys.get("one_time_keys", None)
+        if one_time_keys:
+            logger.info(
+                "Adding %d one_time_keys for device %r for user %r at %d",
+                len(one_time_keys), device_id, user_id, time_now
+            )
+            key_list = []
+            for key_id, key_json in one_time_keys.items():
+                algorithm, key_id = key_id.split(":")
+                key_list.append((
+                    algorithm, key_id, encode_canonical_json(key_json)
+                ))
+
+            yield self.store.add_e2e_one_time_keys(
+                user_id, device_id, time_now, key_list
+            )
+
+        # the device should have been registered already, but it may have been
+        # deleted due to a race with a DELETE request. Or we may be using an
+        # old access_token without an associated device_id. Either way, we
+        # need to double-check the device is registered to avoid ending up with
+        # keys without a corresponding device.
+        yield self.device_handler.check_device_registered(user_id, device_id)
+
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+
+        defer.returnValue({"one_time_key_counts": result})
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 16dbddee03..b047ae2250 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -217,7 +217,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         logger.info(
-            "Performing _on_shutdown. Persiting %d unpersisted changes",
+            "Performing _on_shutdown. Persisting %d unpersisted changes",
             len(self.user_to_current_state)
         )
 
@@ -234,7 +234,7 @@ class PresenceHandler(object):
         may stack up and slow down shutdown times.
         """
         logger.info(
-            "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
+            "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
             len(self.unpersisted_users_changes)
         )
 
@@ -625,18 +625,8 @@ class PresenceHandler(object):
         Args:
             hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
         """
-        now = self.clock.time_msec()
         for host, states in hosts_to_states.items():
-            self.federation.send_edu(
-                destination=host,
-                edu_type="m.presence",
-                content={
-                    "push": [
-                        _format_user_presence_state(state, now)
-                        for state in states
-                    ]
-                }
-            )
+            self.federation.send_presence(host, states)
 
     @defer.inlineCallbacks
     def incoming_presence(self, origin, content):
@@ -723,13 +713,13 @@ class PresenceHandler(object):
             defer.returnValue([
                 {
                     "type": "m.presence",
-                    "content": _format_user_presence_state(state, now),
+                    "content": format_user_presence_state(state, now),
                 }
                 for state in updates
             ])
         else:
             defer.returnValue([
-                _format_user_presence_state(state, now) for state in updates
+                format_user_presence_state(state, now) for state in updates
             ])
 
     @defer.inlineCallbacks
@@ -988,7 +978,7 @@ def should_notify(old_state, new_state):
     return False
 
 
-def _format_user_presence_state(state, now):
+def format_user_presence_state(state, now):
     """Convert UserPresenceState to a format that can be sent down to clients
     and to other servers.
     """
@@ -1101,7 +1091,7 @@ class PresenceEventSource(object):
         defer.returnValue(([
             {
                 "type": "m.presence",
-                "content": _format_user_presence_state(s, now),
+                "content": format_user_presence_state(s, now),
             }
             for s in updates.values()
             if include_offline or s.state != PresenceState.OFFLINE
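
For context on the presence change: the removed inline send_edu call shows the payload that federation.send_presence is now expected to build. Assuming the transaction queue constructs the same m.presence EDU, it amounts to the following (a sketch, not the new sender code itself):

    self.federation.send_edu(
        destination=host,
        edu_type="m.presence",
        content={
            "push": [
                format_user_presence_state(state, self.clock.time_msec())
                for state in states
            ],
        },
    )
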
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 726f7308d2..e536a909d0 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
                             }
                         },
                     },
+                    key=(room_id, receipt_type, user_id),
                 )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 8758af4ca1..d40ada60c1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -196,6 +196,11 @@ class RoomCreationHandler(BaseHandler):
                 },
                 ratelimit=False)
 
+        content = {}
+        is_direct = config.get("is_direct", None)
+        if is_direct:
+            content["is_direct"] = is_direct
+
         for invitee in invite_list:
             yield room_member_handler.update_membership(
                 requester,
@@ -203,6 +208,7 @@ class RoomCreationHandler(BaseHandler):
                 room_id,
                 "invite",
                 ratelimit=False,
+                content=content,
             )
 
         for invite_3pid in invite_3pid_list:
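
The room.py change forwards the client's is_direct flag from the /createRoom body onto each invite, so invited clients can file the new room as a direct chat. An illustrative sketch of the flow (the user ID is made up):

    # /createRoom request body supplied by the client:
    config = {"invite": ["@bob:example.org"], "is_direct": True}

    # content attached to each m.room.member invite sent by the loop above:
    content = {}
    if config.get("is_direct"):
        content["is_direct"] = config["is_direct"]
    # => {"is_direct": True}, delivered alongside membership "invite"
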
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3b687957dd..0548b81c34 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -187,6 +187,7 @@ class TypingHandler(object):
                         "user_id": user_id,
                         "typing": typing,
                     },
+                    key=(room_id, user_id),
                 ))
 
         yield preserve_context_over_deferred(
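
Both the receipts and typing changes pass a key to send_edu. Assuming the federation sender keeps at most one pending EDU per (destination, key), a newer typing notification or read receipt for the same room and user replaces a stale queued one instead of piling up behind it. A hedged sketch of that keyed behaviour (the attribute names below are assumptions, not the real TransactionQueue internals):

    def send_edu(self, destination, edu_type, content, key=None):
        edu = {"edu_type": edu_type, "content": content}
        if key is not None:
            # A keyed EDU overwrites any still-queued EDU with the same key,
            # so only the latest state for e.g. (room_id, user_id) is sent.
            self.pending_edus_keyed.setdefault(destination, {})[(edu_type, key)] = edu
        else:
            self.pending_edus.setdefault(destination, []).append(edu)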