summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/app/synchrotron.py3
-rw-r--r--synapse/federation/federation_client.py22
-rw-r--r--synapse/federation/transaction_queue.py50
-rw-r--r--synapse/federation/transport/client.py6
-rw-r--r--synapse/handlers/e2e_keys.py184
-rw-r--r--synapse/handlers/presence.py24
-rw-r--r--synapse/handlers/receipts.py1
-rw-r--r--synapse/handlers/room.py6
-rw-r--r--synapse/handlers/typing.py1
-rw-r--r--synapse/http/matrixfederationclient.py11
-rw-r--r--synapse/replication/resource.py9
-rw-r--r--synapse/rest/client/v2_alpha/keys.py129
-rw-r--r--synapse/storage/background_updates.py24
-rw-r--r--synapse/storage/deviceinbox.py2
-rw-r--r--synapse/storage/event_push_actions.py9
-rw-r--r--synapse/storage/events.py8
-rw-r--r--synapse/storage/schema/delta/22/receipts_index.sql4
-rw-r--r--synapse/storage/schema/delta/28/events_room_stream.sql4
-rw-r--r--synapse/storage/schema/delta/28/public_roms_index.sql4
-rw-r--r--synapse/storage/schema/delta/28/receipts_user_id_index.sql4
-rw-r--r--synapse/storage/schema/delta/29/push_actions.sql4
-rw-r--r--synapse/storage/schema/delta/31/pushers_index.sql4
-rw-r--r--synapse/storage/schema/delta/35/contains_url.sql17
-rw-r--r--synapse/storage/schema/delta/35/event_push_actions_index.sql5
-rw-r--r--synapse/storage/state.py56
25 files changed, 405 insertions, 186 deletions
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 07d3d047c6..dbaa48035d 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -242,6 +242,9 @@ class SynchrotronTyping(object):
         self._room_typing = {}
 
     def stream_positions(self):
+        # We must update this typing token from the response of the previous
+        # sync. In particular, the stream id may "reset" back to zero/a low
+        # value which we *must* use for the next replication request.
         return {"typing": self._latest_room_serial}
 
     def process_replication(self, result):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 78719eed25..91bed4746f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -122,8 +122,12 @@ class FederationClient(FederationBase):
             pdu.event_id
         )
 
+    def send_presence(self, destination, states):
+        if destination != self.server_name:
+            self._transaction_queue.enqueue_presence(destination, states)
+
     @log_function
-    def send_edu(self, destination, edu_type, content):
+    def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
             origin=self.server_name,
             destination=destination,
@@ -134,7 +138,7 @@ class FederationClient(FederationBase):
         sent_edus_counter.inc()
 
         # TODO, add errback, etc.
-        self._transaction_queue.enqueue_edu(edu)
+        self._transaction_queue.enqueue_edu(edu, key=key)
         return defer.succeed(None)
 
     @log_function
@@ -172,7 +176,7 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content):
+    def query_client_keys(self, destination, content, timeout):
         """Query device keys for a device hosted on a remote server.
 
         Args:
@@ -184,10 +188,12 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_device_keys")
-        return self.transport_layer.query_client_keys(destination, content)
+        return self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content):
+    def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
@@ -199,7 +205,9 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_one_time_keys")
-        return self.transport_layer.claim_client_keys(destination, content)
+        return self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -477,7 +485,7 @@ class FederationClient(FederationBase):
                 defer.DeferredList(deferreds, consumeErrors=True)
             )
             for success, result in res:
-                if success:
+                if success and result:
                     signed_events.append(result)
                     batch.discard(result.event_id)
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 1ac569b305..f8ca93e4c3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,6 +26,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
 import logging
@@ -69,13 +70,21 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
 
+        # Presence needs to be separate as we send single aggragate EDUs
+        self.pending_presence_by_dest = presence = {}
+        self.pending_edus_keyed_by_dest = edus_keyed = {}
+
         metrics.register_callback(
             "pending_pdus",
             lambda: sum(map(len, pdus.values())),
         )
         metrics.register_callback(
             "pending_edus",
-            lambda: sum(map(len, edus.values())),
+            lambda: (
+                sum(map(len, edus.values()))
+                + sum(map(len, presence.values()))
+                + sum(map(len, edus_keyed.values()))
+            ),
         )
 
         # destination -> list of tuple(failure, deferred)
@@ -130,13 +139,27 @@ class TransactionQueue(object):
                 self._attempt_new_transaction, destination
             )
 
-    def enqueue_edu(self, edu):
+    def enqueue_presence(self, destination, states):
+        self.pending_presence_by_dest.setdefault(destination, {}).update({
+            state.user_id: state for state in states
+        })
+
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
+
+    def enqueue_edu(self, edu, key=None):
         destination = edu.destination
 
         if not self.can_send_to(destination):
             return
 
-        self.pending_edus_by_dest.setdefault(destination, []).append(edu)
+        if key:
+            self.pending_edus_keyed_by_dest.setdefault(
+                destination, {}
+            )[(edu.edu_type, key)] = edu
+        else:
+            self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
         preserve_context_over_fn(
             self._attempt_new_transaction, destination
@@ -190,8 +213,13 @@ class TransactionQueue(object):
             while True:
                     pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
                     pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
                     pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
+                    pending_edus.extend(
+                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                    )
+
                     limiter = yield get_retry_limiter(
                         destination,
                         self.clock,
@@ -203,6 +231,22 @@ class TransactionQueue(object):
                     )
 
                     pending_edus.extend(device_message_edus)
+                    if pending_presence:
+                        pending_edus.append(
+                            Edu(
+                                origin=self.server_name,
+                                destination=destination,
+                                edu_type="m.presence",
+                                content={
+                                    "push": [
+                                        format_user_presence_state(
+                                            presence, self.clock.time_msec()
+                                        )
+                                        for presence in pending_presence.values()
+                                    ]
+                                },
+                            )
+                        )
 
                     if pending_pdus:
                         logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 3d088e43cb..2b138526ba 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -298,7 +298,7 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content):
+    def query_client_keys(self, destination, query_content, timeout):
         """Query the device keys for a list of user ids hosted on a remote
         server.
 
@@ -327,12 +327,13 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
     @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content):
+    def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
         Request:
@@ -363,6 +364,7 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
+            timeout=timeout,
         )
         defer.returnValue(content)
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 2c7bfd91ed..fd11935b40 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,14 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
-import json
+import ujson as json
 import logging
 
+from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-from synapse.api import errors
-import synapse.types
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -29,8 +31,9 @@ class E2eKeysHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
+        self.device_handler = hs.get_device_handler()
         self.is_mine_id = hs.is_mine_id
-        self.server_name = hs.hostname
+        self.clock = hs.get_clock()
 
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
@@ -40,7 +43,7 @@ class E2eKeysHandler(object):
         )
 
     @defer.inlineCallbacks
-    def query_devices(self, query_body):
+    def query_devices(self, query_body, timeout):
         """ Handle a device key query from a client
 
         {
@@ -63,27 +66,60 @@ class E2eKeysHandler(object):
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
-        queries_by_domain = collections.defaultdict(dict)
+        local_query = {}
+        remote_queries = {}
+
         for user_id, device_ids in device_keys_query.items():
-            user = synapse.types.UserID.from_string(user_id)
-            queries_by_domain[user.domain][user_id] = device_ids
+            if self.is_mine_id(user_id):
+                local_query[user_id] = device_ids
+            else:
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_ids
 
         # do the queries
-        # TODO: do these in parallel
+        failures = {}
         results = {}
-        for destination, destination_query in queries_by_domain.items():
-            if destination == self.server_name:
-                res = yield self.query_local_devices(destination_query)
-            else:
-                res = yield self.federation.query_client_keys(
-                    destination, {"device_keys": destination_query}
-                )
-                res = res["device_keys"]
-            for user_id, keys in res.items():
-                if user_id in destination_query:
+        if local_query:
+            local_result = yield self.query_local_devices(local_query)
+            for user_id, keys in local_result.items():
+                if user_id in local_query:
                     results[user_id] = keys
 
-        defer.returnValue((200, {"device_keys": results}))
+        @defer.inlineCallbacks
+        def do_remote_query(destination):
+            destination_query = remote_queries[destination]
+            try:
+                limiter = yield get_retry_limiter(
+                    destination, self.clock, self.store
+                )
+                with limiter:
+                    remote_result = yield self.federation.query_client_keys(
+                        destination,
+                        {"device_keys": destination_query},
+                        timeout=timeout
+                    )
+
+                for user_id, keys in remote_result["device_keys"].items():
+                    if user_id in destination_query:
+                        results[user_id] = keys
+
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+            except NotRetryingDestination as e:
+                failures[destination] = {
+                    "status": 503, "message": "Not ready for retry",
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(do_remote_query)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue({
+            "device_keys": results, "failures": failures,
+        })
 
     @defer.inlineCallbacks
     def query_local_devices(self, query):
@@ -104,7 +140,7 @@ class E2eKeysHandler(object):
             if not self.is_mine_id(user_id):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
-                raise errors.SynapseError(400, "Not a user here")
+                raise SynapseError(400, "Not a user here")
 
             if not device_ids:
                 local_query.append((user_id, None))
@@ -137,3 +173,107 @@ class E2eKeysHandler(object):
         device_keys_query = query_body.get("device_keys", {})
         res = yield self.query_local_devices(device_keys_query)
         defer.returnValue({"device_keys": res})
+
+    @defer.inlineCallbacks
+    def claim_one_time_keys(self, query, timeout):
+        local_query = []
+        remote_queries = {}
+
+        for user_id, device_keys in query.get("one_time_keys", {}).items():
+            if self.is_mine_id(user_id):
+                for device_id, algorithm in device_keys.items():
+                    local_query.append((user_id, device_id, algorithm))
+            else:
+                domain = get_domain_from_id(user_id)
+                remote_queries.setdefault(domain, {})[user_id] = device_keys
+
+        results = yield self.store.claim_e2e_one_time_keys(local_query)
+
+        json_result = {}
+        failures = {}
+        for user_id, device_keys in results.items():
+            for device_id, keys in device_keys.items():
+                for key_id, json_bytes in keys.items():
+                    json_result.setdefault(user_id, {})[device_id] = {
+                        key_id: json.loads(json_bytes)
+                    }
+
+        @defer.inlineCallbacks
+        def claim_client_keys(destination):
+            device_keys = remote_queries[destination]
+            try:
+                limiter = yield get_retry_limiter(
+                    destination, self.clock, self.store
+                )
+                with limiter:
+                    remote_result = yield self.federation.claim_client_keys(
+                        destination,
+                        {"one_time_keys": device_keys},
+                        timeout=timeout
+                    )
+                    for user_id, keys in remote_result["one_time_keys"].items():
+                        if user_id in device_keys:
+                            json_result[user_id] = keys
+            except CodeMessageException as e:
+                failures[destination] = {
+                    "status": e.code, "message": e.message
+                }
+            except NotRetryingDestination as e:
+                failures[destination] = {
+                    "status": 503, "message": "Not ready for retry",
+                }
+
+        yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(claim_client_keys)(destination)
+            for destination in remote_queries
+        ]))
+
+        defer.returnValue({
+            "one_time_keys": json_result,
+            "failures": failures
+        })
+
+    @defer.inlineCallbacks
+    def upload_keys_for_user(self, user_id, device_id, keys):
+        time_now = self.clock.time_msec()
+
+        # TODO: Validate the JSON to make sure it has the right keys.
+        device_keys = keys.get("device_keys", None)
+        if device_keys:
+            logger.info(
+                "Updating device_keys for device %r for user %s at %d",
+                device_id, user_id, time_now
+            )
+            # TODO: Sign the JSON with the server key
+            yield self.store.set_e2e_device_keys(
+                user_id, device_id, time_now,
+                encode_canonical_json(device_keys)
+            )
+
+        one_time_keys = keys.get("one_time_keys", None)
+        if one_time_keys:
+            logger.info(
+                "Adding %d one_time_keys for device %r for user %r at %d",
+                len(one_time_keys), device_id, user_id, time_now
+            )
+            key_list = []
+            for key_id, key_json in one_time_keys.items():
+                algorithm, key_id = key_id.split(":")
+                key_list.append((
+                    algorithm, key_id, encode_canonical_json(key_json)
+                ))
+
+            yield self.store.add_e2e_one_time_keys(
+                user_id, device_id, time_now, key_list
+            )
+
+        # the device should have been registered already, but it may have been
+        # deleted due to a race with a DELETE request. Or we may be using an
+        # old access_token without an associated device_id. Either way, we
+        # need to double-check the device is registered to avoid ending up with
+        # keys without a corresponding device.
+        self.device_handler.check_device_registered(user_id, device_id)
+
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+
+        defer.returnValue({"one_time_key_counts": result})
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 16dbddee03..b047ae2250 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -217,7 +217,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         logger.info(
-            "Performing _on_shutdown. Persiting %d unpersisted changes",
+            "Performing _on_shutdown. Persisting %d unpersisted changes",
             len(self.user_to_current_state)
         )
 
@@ -234,7 +234,7 @@ class PresenceHandler(object):
         may stack up and slow down shutdown times.
         """
         logger.info(
-            "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
+            "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
             len(self.unpersisted_users_changes)
         )
 
@@ -625,18 +625,8 @@ class PresenceHandler(object):
         Args:
             hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
         """
-        now = self.clock.time_msec()
         for host, states in hosts_to_states.items():
-            self.federation.send_edu(
-                destination=host,
-                edu_type="m.presence",
-                content={
-                    "push": [
-                        _format_user_presence_state(state, now)
-                        for state in states
-                    ]
-                }
-            )
+            self.federation.send_presence(host, states)
 
     @defer.inlineCallbacks
     def incoming_presence(self, origin, content):
@@ -723,13 +713,13 @@ class PresenceHandler(object):
             defer.returnValue([
                 {
                     "type": "m.presence",
-                    "content": _format_user_presence_state(state, now),
+                    "content": format_user_presence_state(state, now),
                 }
                 for state in updates
             ])
         else:
             defer.returnValue([
-                _format_user_presence_state(state, now) for state in updates
+                format_user_presence_state(state, now) for state in updates
             ])
 
     @defer.inlineCallbacks
@@ -988,7 +978,7 @@ def should_notify(old_state, new_state):
     return False
 
 
-def _format_user_presence_state(state, now):
+def format_user_presence_state(state, now):
     """Convert UserPresenceState to a format that can be sent down to clients
     and to other servers.
     """
@@ -1101,7 +1091,7 @@ class PresenceEventSource(object):
         defer.returnValue(([
             {
                 "type": "m.presence",
-                "content": _format_user_presence_state(s, now),
+                "content": format_user_presence_state(s, now),
             }
             for s in updates.values()
             if include_offline or s.state != PresenceState.OFFLINE
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 726f7308d2..e536a909d0 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
                             }
                         },
                     },
+                    key=(room_id, receipt_type, user_id),
                 )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 8758af4ca1..d40ada60c1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -196,6 +196,11 @@ class RoomCreationHandler(BaseHandler):
                 },
                 ratelimit=False)
 
+        content = {}
+        is_direct = config.get("is_direct", None)
+        if is_direct:
+            content["is_direct"] = is_direct
+
         for invitee in invite_list:
             yield room_member_handler.update_membership(
                 requester,
@@ -203,6 +208,7 @@ class RoomCreationHandler(BaseHandler):
                 room_id,
                 "invite",
                 ratelimit=False,
+                content=content,
             )
 
         for invite_3pid in invite_3pid_list:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3b687957dd..0548b81c34 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -187,6 +187,7 @@ class TypingHandler(object):
                         "user_id": user_id,
                         "typing": typing,
                     },
+                    key=(room_id, user_id),
                 ))
 
         yield preserve_context_over_deferred(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index f93093dd85..d0556ae347 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False):
+                 long_retries=False, timeout=None):
         """ Sends the specifed json data using PUT
 
         Args:
@@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object):
                 use as the request body.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
+            timeout(int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
+            timeout=timeout,
         )
 
         if 200 <= response.code < 300:
@@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object):
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def post_json(self, destination, path, data={}, long_retries=True):
+    def post_json(self, destination, path, data={}, long_retries=True,
+                  timeout=None):
         """ Sends the specifed json data using POST
 
         Args:
@@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object):
                 the request body. This will be encoded as JSON.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
+            timeout(int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=True,
+            timeout=timeout,
         )
 
         if 200 <= response.code < 300:
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 857bc9795c..299e9419a4 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -274,11 +274,18 @@ class ReplicationResource(Resource):
 
     @defer.inlineCallbacks
     def typing(self, writer, current_token, request_streams):
-        current_position = current_token.presence
+        current_position = current_token.typing
 
         request_typing = request_streams.get("typing")
 
         if request_typing is not None:
+            # If they have a higher token than current max, we can assume that
+            # they had been talking to a previous instance of the master. Since
+            # we reset the token on restart, the best (but hacky) thing we can
+            # do is to simply resend down all the typing notifications.
+            if request_typing > current_position:
+                request_typing = 0
+
             typing_rows = yield self.typing_handler.get_all_typing_updates(
                 request_typing, current_position
             )
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index c5ff16adf3..f185f9a774 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -15,15 +15,12 @@
 
 import logging
 
-import simplejson as json
-from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-import synapse.api.errors
-import synapse.server
-import synapse.types
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request, parse_integer
+)
 from ._base import client_v2_patterns
 
 logger = logging.getLogger(__name__)
@@ -63,17 +60,13 @@ class KeyUploadServlet(RestServlet):
             hs (synapse.server.HomeServer): server
         """
         super(KeyUploadServlet, self).__init__()
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
         self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, device_id):
         requester = yield self.auth.get_user_by_req(request)
-
         user_id = requester.user.to_string()
-
         body = parse_json_object_from_request(request)
 
         if device_id is not None:
@@ -88,52 +81,15 @@ class KeyUploadServlet(RestServlet):
             device_id = requester.device_id
 
         if device_id is None:
-            raise synapse.api.errors.SynapseError(
+            raise SynapseError(
                 400,
                 "To upload keys, you must pass device_id when authenticating"
             )
 
-        time_now = self.clock.time_msec()
-
-        # TODO: Validate the JSON to make sure it has the right keys.
-        device_keys = body.get("device_keys", None)
-        if device_keys:
-            logger.info(
-                "Updating device_keys for device %r for user %s at %d",
-                device_id, user_id, time_now
-            )
-            # TODO: Sign the JSON with the server key
-            yield self.store.set_e2e_device_keys(
-                user_id, device_id, time_now,
-                encode_canonical_json(device_keys)
-            )
-
-        one_time_keys = body.get("one_time_keys", None)
-        if one_time_keys:
-            logger.info(
-                "Adding %d one_time_keys for device %r for user %r at %d",
-                len(one_time_keys), device_id, user_id, time_now
-            )
-            key_list = []
-            for key_id, key_json in one_time_keys.items():
-                algorithm, key_id = key_id.split(":")
-                key_list.append((
-                    algorithm, key_id, encode_canonical_json(key_json)
-                ))
-
-            yield self.store.add_e2e_one_time_keys(
-                user_id, device_id, time_now, key_list
-            )
-
-        # the device should have been registered already, but it may have been
-        # deleted due to a race with a DELETE request. Or we may be using an
-        # old access_token without an associated device_id. Either way, we
-        # need to double-check the device is registered to avoid ending up with
-        # keys without a corresponding device.
-        self.device_handler.check_device_registered(user_id, device_id)
-
-        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
-        defer.returnValue((200, {"one_time_key_counts": result}))
+        result = yield self.e2e_keys_handler.upload_keys_for_user(
+            user_id, device_id, body
+        )
+        defer.returnValue((200, result))
 
 
 class KeyQueryServlet(RestServlet):
@@ -195,20 +151,23 @@ class KeyQueryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.e2e_keys_handler.query_devices(body)
-        defer.returnValue(result)
+        result = yield self.e2e_keys_handler.query_devices(body, timeout)
+        defer.returnValue((200, result))
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id):
         requester = yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         auth_user_id = requester.user.to_string()
         user_id = user_id if user_id else auth_user_id
         device_ids = [device_id] if device_id else []
         result = yield self.e2e_keys_handler.query_devices(
-            {"device_keys": {user_id: device_ids}}
+            {"device_keys": {user_id: device_ids}},
+            timeout,
         )
-        defer.returnValue(result)
+        defer.returnValue((200, result))
 
 
 class OneTimeKeyServlet(RestServlet):
@@ -240,59 +199,29 @@ class OneTimeKeyServlet(RestServlet):
 
     def __init__(self, hs):
         super(OneTimeKeyServlet, self).__init__()
-        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.federation = hs.get_replication_layer()
-        self.is_mine = hs.is_mine
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        result = yield self.handle_request(
-            {"one_time_keys": {user_id: {device_id: algorithm}}}
+        timeout = parse_integer(request, "timeout", 10 * 1000)
+        result = yield self.e2e_keys_handler.claim_one_time_keys(
+            {"one_time_keys": {user_id: {device_id: algorithm}}},
+            timeout,
         )
-        defer.returnValue(result)
+        defer.returnValue((200, result))
 
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = yield self.handle_request(body)
-        defer.returnValue(result)
-
-    @defer.inlineCallbacks
-    def handle_request(self, body):
-        local_query = []
-        remote_queries = {}
-        for user_id, device_keys in body.get("one_time_keys", {}).items():
-            user = UserID.from_string(user_id)
-            if self.is_mine(user):
-                for device_id, algorithm in device_keys.items():
-                    local_query.append((user_id, device_id, algorithm))
-            else:
-                remote_queries.setdefault(user.domain, {})[user_id] = (
-                    device_keys
-                )
-        results = yield self.store.claim_e2e_one_time_keys(local_query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
-                    json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_bytes)
-                    }
-
-        for destination, device_keys in remote_queries.items():
-            remote_result = yield self.federation.claim_client_keys(
-                destination, {"one_time_keys": device_keys}
-            )
-            for user_id, keys in remote_result["one_time_keys"].items():
-                if user_id in device_keys:
-                    json_result[user_id] = keys
-
-        defer.returnValue((200, {"one_time_keys": json_result}))
+        result = yield self.e2e_keys_handler.claim_one_time_keys(
+            body,
+            timeout,
+        )
+        defer.returnValue((200, result))
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 003f5ba203..94b2bcc54a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -219,7 +219,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_handlers[update_name] = update_handler
 
     def register_background_index_update(self, update_name, index_name,
-                                         table, columns):
+                                         table, columns, where_clause=None):
         """Helper for store classes to do a background index addition
 
         To use:
@@ -243,14 +243,20 @@ class BackgroundUpdateStore(SQLBaseStore):
             conc = True
         else:
             conc = False
-
-        sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
-              % {
-                  "conc": "CONCURRENTLY" if conc else "",
-                  "name": index_name,
-                  "table": table,
-                  "columns": ", ".join(columns),
-              }
+            # We don't use partial indices on SQLite as it wasn't introduced
+            # until 3.8, and wheezy has 3.7
+            where_clause = None
+
+        sql = (
+            "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
+            " %(where_clause)s"
+        ) % {
+            "conc": "CONCURRENTLY" if conc else "",
+            "name": index_name,
+            "table": table,
+            "columns": ", ".join(columns),
+            "where_clause": "WHERE " + where_clause if where_clause else ""
+        }
 
         def create_index_concurrently(conn):
             conn.rollback()
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index b729b7106e..f640e73714 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -128,6 +128,8 @@ class DeviceInboxStore(SQLBaseStore):
                     user_id, stream_id
                 )
 
+        defer.returnValue(stream_id)
+
     def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                 messages_by_user_then_device):
         sql = (
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index a87d90741a..9cd923eb93 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -26,10 +26,19 @@ logger = logging.getLogger(__name__)
 
 
 class EventPushActionsStore(SQLBaseStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
     def __init__(self, hs):
         self.stream_ordering_month_ago = None
         super(EventPushActionsStore, self).__init__(hs)
 
+        self.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
     def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
         """
         Args:
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ed182c8d11..6dc46fa50f 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -189,6 +189,14 @@ class EventsStore(SQLBaseStore):
             self._background_reindex_fields_sender,
         )
 
+        self.register_background_index_update(
+            "event_contains_url_index",
+            index_name="event_contains_url_index",
+            table="events",
+            columns=["room_id", "topological_ordering", "stream_ordering"],
+            where_clause="contains_url = true AND outlier = false",
+        )
+
         self._event_persist_queue = _EventPeristenceQueue()
 
     def persist_events(self, events_and_contexts, backfilled=False):
diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/schema/delta/22/receipts_index.sql
index 7bc061dff6..bfc0b3bcaa 100644
--- a/synapse/storage/schema/delta/22/receipts_index.sql
+++ b/synapse/storage/schema/delta/22/receipts_index.sql
@@ -13,6 +13,10 @@
  * limitations under the License.
  */
 
+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX receipts_linearized_room_stream ON receipts_linearized(
     room_id, stream_id
 );
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql
index 200c35e6e2..36609475f1 100644
--- a/synapse/storage/schema/delta/28/events_room_stream.sql
+++ b/synapse/storage/schema/delta/28/events_room_stream.sql
@@ -13,4 +13,8 @@
  * limitations under the License.
 */
 
+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX events_room_stream on events(room_id, stream_ordering);
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql
index ba62a974a4..6c1fd68c5b 100644
--- a/synapse/storage/schema/delta/28/public_roms_index.sql
+++ b/synapse/storage/schema/delta/28/public_roms_index.sql
@@ -13,4 +13,8 @@
  * limitations under the License.
 */
 
+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX public_room_index on rooms(is_public);
diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/schema/delta/28/receipts_user_id_index.sql
index 452a1b3c6c..cb84c69baa 100644
--- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql
+++ b/synapse/storage/schema/delta/28/receipts_user_id_index.sql
@@ -13,6 +13,10 @@
  * limitations under the License.
  */
 
+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX receipts_linearized_user ON receipts_linearized(
     user_id
 );
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql
index 7e7b09820a..84b21cf813 100644
--- a/synapse/storage/schema/delta/29/push_actions.sql
+++ b/synapse/storage/schema/delta/29/push_actions.sql
@@ -26,6 +26,10 @@ UPDATE event_push_actions SET stream_ordering = (
 
 UPDATE event_push_actions SET notif = 1, highlight = 0;
 
+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
 CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
     user_id, room_id, topological_ordering, stream_ordering
 );
diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/schema/delta/31/pushers_index.sql
index 9027bccc69..a82add88fd 100644
--- a/synapse/storage/schema/delta/31/pushers_index.sql
+++ b/synapse/storage/schema/delta/31/pushers_index.sql
@@ -13,6 +13,10 @@
  * limitations under the License.
  */
 
+/** Using CREATE INDEX directly is deprecated in favour of using background
+ * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql
+ * and synapse/storage/registration.py for an example using
+ * "access_tokens_device_index" **/
  CREATE INDEX event_push_actions_stream_ordering on event_push_actions(
      stream_ordering, user_id
  );
diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/schema/delta/35/contains_url.sql
new file mode 100644
index 0000000000..6cd123027b
--- /dev/null
+++ b/synapse/storage/schema/delta/35/contains_url.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ INSERT into background_updates (update_name, progress_json)
+     VALUES ('event_contains_url_index', '{}');
diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/schema/delta/35/event_push_actions_index.sql
index 4fc32c351a..2e836d8e9c 100644
--- a/synapse/storage/schema/delta/35/event_push_actions_index.sql
+++ b/synapse/storage/schema/delta/35/event_push_actions_index.sql
@@ -13,6 +13,5 @@
  * limitations under the License.
  */
 
- CREATE INDEX event_push_actions_user_id_highlight_stream_ordering on event_push_actions(
-     user_id, highlight, stream_ordering
- );
+ INSERT into background_updates (update_name, progress_json)
+     VALUES ('epa_highlight_index', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0cff0a0cda..fdbdade536 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -306,13 +306,6 @@ class StateStore(SQLBaseStore):
         defer.returnValue(results)
 
     def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
-        if types is not None:
-            where_clause = "AND (%s)" % (
-                " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-            )
-        else:
-            where_clause = ""
-
         results = {group: {} for group in groups}
         if isinstance(self.database_engine, PostgresEngine):
             # Temporarily disable sequential scans in this transaction. This is
@@ -342,20 +335,43 @@ class StateStore(SQLBaseStore):
                 WHERE state_group IN (
                     SELECT state_group FROM state
                 )
-                %s;
-            """) % (where_clause,)
-
-            for group in groups:
-                args = [group]
-                if types is not None:
-                    args.extend([i for typ in types for i in typ])
+                %s
+            """)
 
-                txn.execute(sql, args)
-                rows = self.cursor_to_dict(txn)
-                for row in rows:
-                    key = (row["type"], row["state_key"])
-                    results[group][key] = row["event_id"]
+            # Turns out that postgres doesn't like doing a list of OR's and
+            # is about 1000x slower, so we just issue a query for each specific
+            # type seperately.
+            if types:
+                clause_to_args = [
+                    (
+                        "AND type = ? AND state_key = ?",
+                        (etype, state_key)
+                    )
+                    for etype, state_key in types
+                ]
+            else:
+                # If types is None we fetch all the state, and so just use an
+                # empty where clause with no extra args.
+                clause_to_args = [("", [])]
+
+            for where_clause, where_args in clause_to_args:
+                for group in groups:
+                    args = [group]
+                    args.extend(where_args)
+
+                    txn.execute(sql % (where_clause,), args)
+                    rows = self.cursor_to_dict(txn)
+                    for row in rows:
+                        key = (row["type"], row["state_key"])
+                        results[group][key] = row["event_id"]
         else:
+            if types is not None:
+                where_clause = "AND (%s)" % (
+                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+                )
+            else:
+                where_clause = ""
+
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
@@ -804,7 +820,7 @@ class StateStore(SQLBaseStore):
         def reindex_txn(txn):
             if isinstance(self.database_engine, PostgresEngine):
                 txn.execute(
-                    "CREATE INDEX state_groups_state_type_idx"
+                    "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
                     " ON state_groups_state(state_group, type, state_key)"
                 )
                 txn.execute(