25 files changed, 405 insertions, 186 deletions
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 07d3d047c6..dbaa48035d 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -242,6 +242,9 @@ class SynchrotronTyping(object): self._room_typing = {} def stream_positions(self): + # We must update this typing token from the response of the previous + # sync. In particular, the stream id may "reset" back to zero/a low + # value which we *must* use for the next replication request. return {"typing": self._latest_room_serial} def process_replication(self, result): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 78719eed25..91bed4746f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -122,8 +122,12 @@ class FederationClient(FederationBase): pdu.event_id ) + def send_presence(self, destination, states): + if destination != self.server_name: + self._transaction_queue.enqueue_presence(destination, states) + @log_function - def send_edu(self, destination, edu_type, content): + def send_edu(self, destination, edu_type, content, key=None): edu = Edu( origin=self.server_name, destination=destination, @@ -134,7 +138,7 @@ class FederationClient(FederationBase): sent_edus_counter.inc() # TODO, add errback, etc. - self._transaction_queue.enqueue_edu(edu) + self._transaction_queue.enqueue_edu(edu, key=key) return defer.succeed(None) @log_function @@ -172,7 +176,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content): + def query_client_keys(self, destination, content, timeout): """Query device keys for a device hosted on a remote server. Args: @@ -184,10 +188,12 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys(destination, content) + return self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def claim_client_keys(self, destination, content): + def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. 
Args: @@ -199,7 +205,9 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys(destination, content) + return self.transport_layer.claim_client_keys( + destination, content, timeout + ) @defer.inlineCallbacks @log_function @@ -477,7 +485,7 @@ class FederationClient(FederationBase): defer.DeferredList(deferreds, consumeErrors=True) ) for success, result in res: - if success: + if success and result: signed_events.append(result) batch.discard(result.event_id) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 1ac569b305..f8ca93e4c3 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -26,6 +26,7 @@ from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) from synapse.util.metrics import measure_func +from synapse.handlers.presence import format_user_presence_state import synapse.metrics import logging @@ -69,13 +70,21 @@ class TransactionQueue(object): # destination -> list of tuple(edu, deferred) self.pending_edus_by_dest = edus = {} + # Presence needs to be separate as we send single aggregate EDUs + self.pending_presence_by_dest = presence = {} + self.pending_edus_keyed_by_dest = edus_keyed = {} + metrics.register_callback( "pending_pdus", lambda: sum(map(len, pdus.values())), ) metrics.register_callback( "pending_edus", - lambda: sum(map(len, edus.values())), + lambda: ( + sum(map(len, edus.values())) + + sum(map(len, presence.values())) + + sum(map(len, edus_keyed.values())) + ), ) # destination -> list of tuple(failure, deferred) @@ -130,13 +139,27 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_edu(self, edu): + def enqueue_presence(self, destination, states): + self.pending_presence_by_dest.setdefault(destination, {}).update({ + state.user_id: state for state in states + }) + + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) + + def enqueue_edu(self, edu, key=None): destination = edu.destination if not self.can_send_to(destination): return - self.pending_edus_by_dest.setdefault(destination, []).append(edu) + if key: + self.pending_edus_keyed_by_dest.setdefault( + destination, {} + )[(edu.edu_type, key)] = edu + else: + self.pending_edus_by_dest.setdefault(destination, []).append(edu) preserve_context_over_fn( self._attempt_new_transaction, destination @@ -190,8 +213,13 @@ class TransactionQueue(object): while True: pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_failures = self.pending_failures_by_dest.pop(destination, []) + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) + limiter = yield get_retry_limiter( destination, self.clock, @@ -203,6 +231,22 @@ ) pending_edus.extend(device_message_edus) + if pending_presence: + pending_edus.append( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.clock.time_msec() + ) + for presence in pending_presence.values() + ] + }, + ) + ) if pending_pdus: logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
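The key= argument introduced in enqueue_edu above is what lets ephemeral EDUs supersede one another, and enqueue_presence collapses presence the same way. A minimal sketch of the idea, using simplified stand-ins for TransactionQueue's dictionaries (illustrative only; the real class wires this into _attempt_new_transaction as the diff shows):

    # destination -> {(edu_type, key): edu}. A newer EDU with the same
    # (edu_type, key) replaces the queued one, so e.g. a "stopped typing"
    # EDU keyed on (room_id, user_id) overwrites a stale "started typing" one.
    pending_edus_keyed_by_dest = {}

    # destination -> [edu] for unkeyed EDUs, which are never deduplicated.
    pending_edus_by_dest = {}

    # destination -> {user_id: state}. Only the latest presence state per
    # user survives; it is later flushed as one aggregate m.presence EDU.
    pending_presence_by_dest = {}

    def enqueue_edu(edu, key=None):
        if key:
            pending_edus_keyed_by_dest.setdefault(
                edu.destination, {}
            )[(edu.edu_type, key)] = edu
        else:
            pending_edus_by_dest.setdefault(edu.destination, []).append(edu)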
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 3d088e43cb..2b138526ba 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -298,7 +298,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content): + def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -327,12 +327,13 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content): + def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -363,6 +364,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 2c7bfd91ed..fd11935b40 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,14 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections -import json +import ujson as json import logging +from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.api import errors -import synapse.types +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination logger = logging.getLogger(__name__) @@ -29,8 +31,9 @@ class E2eKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() self.federation = hs.get_replication_layer() + self.device_handler = hs.get_device_handler() self.is_mine_id = hs.is_mine_id - self.server_name = hs.hostname + self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the @@ -40,7 +43,7 @@ ) @defer.inlineCallbacks - def query_devices(self, query_body): + def query_devices(self, query_body, timeout): """ Handle a device key query from a client { @@ -63,27 +66,60 @@ # separate users by domain.
# make a map from domain to user_id to device_ids - queries_by_domain = collections.defaultdict(dict) + local_query = {} + remote_queries = {} + for user_id, device_ids in device_keys_query.items(): - user = synapse.types.UserID.from_string(user_id) - queries_by_domain[user.domain][user_id] = device_ids + if self.is_mine_id(user_id): + local_query[user_id] = device_ids + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_ids # do the queries - # TODO: do these in parallel + failures = {} results = {} - for destination, destination_query in queries_by_domain.items(): - if destination == self.server_name: - res = yield self.query_local_devices(destination_query) - else: - res = yield self.federation.query_client_keys( - destination, {"device_keys": destination_query} - ) - res = res["device_keys"] - for user_id, keys in res.items(): - if user_id in destination_query: + if local_query: + local_result = yield self.query_local_devices(local_query) + for user_id, keys in local_result.items(): + if user_id in local_query: results[user_id] = keys - defer.returnValue((200, {"device_keys": results})) + @defer.inlineCallbacks + def do_remote_query(destination): + destination_query = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout + ) + + for user_id, keys in remote_result["device_keys"].items(): + if user_id in destination_query: + results[user_id] = keys + + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(do_remote_query)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "device_keys": results, "failures": failures, + }) @defer.inlineCallbacks def query_local_devices(self, query): @@ -104,7 +140,7 @@ class E2eKeysHandler(object): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", user_id) - raise errors.SynapseError(400, "Not a user here") + raise SynapseError(400, "Not a user here") if not device_ids: local_query.append((user_id, None)) @@ -137,3 +173,107 @@ class E2eKeysHandler(object): device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) defer.returnValue({"device_keys": res}) + + @defer.inlineCallbacks + def claim_one_time_keys(self, query, timeout): + local_query = [] + remote_queries = {} + + for user_id, device_keys in query.get("one_time_keys", {}).items(): + if self.is_mine_id(user_id): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_keys + + results = yield self.store.claim_e2e_one_time_keys(local_query) + + json_result = {} + failures = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + + @defer.inlineCallbacks + def claim_client_keys(destination): + device_keys = remote_queries[destination] + try: + limiter = yield 
get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(claim_client_keys)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "one_time_keys": json_result, + "failures": failures + }) + + @defer.inlineCallbacks + def upload_keys_for_user(self, user_id, device_id, keys): + time_now = self.clock.time_msec() + + # TODO: Validate the JSON to make sure it has the right keys. + device_keys = keys.get("device_keys", None) + if device_keys: + logger.info( + "Updating device_keys for device %r for user %s at %d", + device_id, user_id, time_now + ) + # TODO: Sign the JSON with the server key + yield self.store.set_e2e_device_keys( + user_id, device_id, time_now, + encode_canonical_json(device_keys) + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + logger.info( + "Adding %d one_time_keys for device %r for user %r at %d", + len(one_time_keys), device_id, user_id, time_now + ) + key_list = [] + for key_id, key_json in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, encode_canonical_json(key_json) + )) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, key_list + ) + + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered(user_id, device_id) + + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + + defer.returnValue({"one_time_key_counts": result}) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 16dbddee03..b047ae2250 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -217,7 +217,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ logger.info( - "Performing _on_shutdown. Persiting %d unpersisted changes", + "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) ) @@ -234,7 +234,7 @@ class PresenceHandler(object): may stack up and slow down shutdown times. """ logger.info( - "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes", + "Performing _persist_unpersisted_changes. 
Persisting %d unpersisted changes", len(self.unpersisted_users_changes) ) @@ -625,18 +625,8 @@ class PresenceHandler(object): Args: hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` """ - now = self.clock.time_msec() for host, states in hosts_to_states.items(): - self.federation.send_edu( - destination=host, - edu_type="m.presence", - content={ - "push": [ - _format_user_presence_state(state, now) - for state in states - ] - } - ) + self.federation.send_presence(host, states) @defer.inlineCallbacks def incoming_presence(self, origin, content): @@ -723,13 +713,13 @@ class PresenceHandler(object): defer.returnValue([ { "type": "m.presence", - "content": _format_user_presence_state(state, now), + "content": format_user_presence_state(state, now), } for state in updates ]) else: defer.returnValue([ - _format_user_presence_state(state, now) for state in updates + format_user_presence_state(state, now) for state in updates ]) @defer.inlineCallbacks @@ -988,7 +978,7 @@ def should_notify(old_state, new_state): return False -def _format_user_presence_state(state, now): +def format_user_presence_state(state, now): """Convert UserPresenceState to a format that can be sent down to clients and to other servers. """ @@ -1101,7 +1091,7 @@ class PresenceEventSource(object): defer.returnValue(([ { "type": "m.presence", - "content": _format_user_presence_state(s, now), + "content": format_user_presence_state(s, now), } for s in updates.values() if include_offline or s.state != PresenceState.OFFLINE diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 726f7308d2..e536a909d0 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler): } }, }, + key=(room_id, receipt_type, user_id), ) @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 8758af4ca1..d40ada60c1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -196,6 +196,11 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + for invitee in invite_list: yield room_member_handler.update_membership( requester, @@ -203,6 +208,7 @@ room_id, "invite", ratelimit=False, + content=content, ) for invite_3pid in invite_3pid_list: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3b687957dd..0548b81c34 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -187,6 +187,7 @@ class TypingHandler(object): "user_id": user_id, "typing": typing, }, + key=(room_id, user_id), )) yield preserve_context_over_deferred( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index f93093dd85..d0556ae347 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def put_json(self, destination, path, data={}, json_data_callback=None, - long_retries=False): + long_retries=False, timeout=None): """ Sends the specified json data using PUT Args: @@ -259,6 +259,8 @@ use as the request body. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout.
Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -285,6 +287,7 @@ body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=long_retries, + timeout=timeout, ) if 200 <= response.code < 300: @@ -300,7 +303,8 @@ defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=True): + def post_json(self, destination, path, data={}, long_retries=True, + timeout=None): """ Sends the specified json data using POST Args: @@ -311,6 +315,8 @@ the request body. This will be encoded as JSON. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -331,6 +337,7 @@ body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=True, + timeout=timeout, ) if 200 <= response.code < 300: diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 857bc9795c..299e9419a4 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -274,11 +274,18 @@ class ReplicationResource(Resource): @defer.inlineCallbacks def typing(self, writer, current_token, request_streams): - current_position = current_token.presence + current_position = current_token.typing request_typing = request_streams.get("typing") if request_typing is not None: + # If they have a higher token than current max, we can assume that + # they had been talking to a previous instance of the master. Since + # we reset the token on restart, the best (but hacky) thing we can + # do is to simply resend down all the typing notifications.
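The worker half of this contract is the SynchrotronTyping handler at the top of this diff: it must echo back whatever serial the master last sent, even if that is lower than anything it saw before. A condensed, illustrative sketch (the shape of the replication response assumed here is not taken from this diff):

    class TypingFollower(object):
        # Follows the master's typing stream over replication.
        def __init__(self):
            self._latest_room_serial = 0

        def stream_positions(self):
            # Echo the serial from the previous sync verbatim: after a master
            # restart the stream id can legitimately go backwards, so treating
            # it as monotonically increasing would wedge the stream.
            return {"typing": self._latest_room_serial}

        def process_replication(self, result):
            stream = result.get("typing")
            if stream:
                self._latest_room_serial = int(stream["position"])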
+ if request_typing > current_position: + request_typing = 0 + typing_rows = yield self.typing_handler.get_all_typing_updates( request_typing, current_position ) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index c5ff16adf3..f185f9a774 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -15,15 +15,12 @@ import logging -import simplejson as json -from canonicaljson import encode_canonical_json from twisted.internet import defer -import synapse.api.errors -import synapse.server -import synapse.types -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_integer +) from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -63,17 +60,13 @@ class KeyUploadServlet(RestServlet): hs (synapse.server.HomeServer): server """ super(KeyUploadServlet, self).__init__() - self.store = hs.get_datastore() - self.clock = hs.get_clock() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_POST(self, request, device_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) if device_id is not None: @@ -88,52 +81,15 @@ class KeyUploadServlet(RestServlet): device_id = requester.device_id if device_id is None: - raise synapse.api.errors.SynapseError( + raise SynapseError( 400, "To upload keys, you must pass device_id when authenticating" ) - time_now = self.clock.time_msec() - - # TODO: Validate the JSON to make sure it has the right keys. - device_keys = body.get("device_keys", None) - if device_keys: - logger.info( - "Updating device_keys for device %r for user %s at %d", - device_id, user_id, time_now - ) - # TODO: Sign the JSON with the server key - yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, - encode_canonical_json(device_keys) - ) - - one_time_keys = body.get("one_time_keys", None) - if one_time_keys: - logger.info( - "Adding %d one_time_keys for device %r for user %r at %d", - len(one_time_keys), device_id, user_id, time_now - ) - key_list = [] - for key_id, key_json in one_time_keys.items(): - algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, encode_canonical_json(key_json) - )) - - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, key_list - ) - - # the device should have been registered already, but it may have been - # deleted due to a race with a DELETE request. Or we may be using an - # old access_token without an associated device_id. Either way, we - # need to double-check the device is registered to avoid ending up with - # keys without a corresponding device. 
- self.device_handler.check_device_registered(user_id, device_id) - - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue((200, {"one_time_key_counts": result})) + result = yield self.e2e_keys_handler.upload_keys_for_user( + user_id, device_id, body + ) + defer.returnValue((200, result)) class KeyQueryServlet(RestServlet): @@ -195,20 +151,23 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body) - defer.returnValue(result) + result = yield self.e2e_keys_handler.query_devices(body, timeout) + defer.returnValue((200, result)) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): requester = yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] result = yield self.e2e_keys_handler.query_devices( - {"device_keys": {user_id: device_ids}} + {"device_keys": {user_id: device_ids}}, + timeout, ) - defer.returnValue(result) + defer.returnValue((200, result)) class OneTimeKeyServlet(RestServlet): @@ -240,59 +199,29 @@ class OneTimeKeyServlet(RestServlet): def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() - self.store = hs.get_datastore() self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - result = yield self.handle_request( - {"one_time_keys": {user_id: {device_id: algorithm}}} + timeout = parse_integer(request, "timeout", 10 * 1000) + result = yield self.e2e_keys_handler.claim_one_time_keys( + {"one_time_keys": {user_id: {device_id: algorithm}}}, + timeout, ) - defer.returnValue(result) + defer.returnValue((200, result)) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) - defer.returnValue(result) - - @defer.inlineCallbacks - def handle_request(self, body): - local_query = [] - remote_queries = {} - for user_id, device_keys in body.get("one_time_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): - for device_id, algorithm in device_keys.items(): - local_query.append((user_id, device_id, algorithm)) - else: - remote_queries.setdefault(user.domain, {})[user_id] = ( - device_keys - ) - results = yield self.store.claim_e2e_one_time_keys(local_query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) - } - - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys} - ) - for user_id, keys in remote_result["one_time_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - - defer.returnValue((200, 
{"one_time_keys": json_result})) + result = yield self.e2e_keys_handler.claim_one_time_keys( + body, + timeout, + ) + defer.returnValue((200, result)) def register_servlets(hs, http_server): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 003f5ba203..94b2bcc54a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -219,7 +219,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_handlers[update_name] = update_handler def register_background_index_update(self, update_name, index_name, - table, columns): + table, columns, where_clause=None): """Helper for store classes to do a background index addition To use: @@ -243,14 +243,20 @@ class BackgroundUpdateStore(SQLBaseStore): conc = True else: conc = False - - sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \ - % { - "conc": "CONCURRENTLY" if conc else "", - "name": index_name, - "table": table, - "columns": ", ".join(columns), - } + # We don't use partial indices on SQLite as it wasn't introduced + # until 3.8, and wheezy has 3.7 + where_clause = None + + sql = ( + "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" + " %(where_clause)s" + ) % { + "conc": "CONCURRENTLY" if conc else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + "where_clause": "WHERE " + where_clause if where_clause else "" + } def create_index_concurrently(conn): conn.rollback() diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index b729b7106e..f640e73714 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -128,6 +128,8 @@ class DeviceInboxStore(SQLBaseStore): user_id, stream_id ) + defer.returnValue(stream_id) + def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, messages_by_user_then_device): sql = ( diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index a87d90741a..9cd923eb93 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -26,10 +26,19 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + def __init__(self, hs): self.stream_ordering_month_ago = None super(EventPushActionsStore, self).__init__(hs) + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ Args: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ed182c8d11..6dc46fa50f 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -189,6 +189,14 @@ class EventsStore(SQLBaseStore): self._background_reindex_fields_sender, ) + self.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + self._event_persist_queue = _EventPeristenceQueue() def persist_events(self, events_and_contexts, backfilled=False): diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/schema/delta/22/receipts_index.sql index 7bc061dff6..bfc0b3bcaa 100644 --- a/synapse/storage/schema/delta/22/receipts_index.sql +++ b/synapse/storage/schema/delta/22/receipts_index.sql @@ -13,6 
+13,10 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql index 200c35e6e2..36609475f1 100644 --- a/synapse/storage/schema/delta/28/events_room_stream.sql +++ b/synapse/storage/schema/delta/28/events_room_stream.sql @@ -13,4 +13,8 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX events_room_stream on events(room_id, stream_ordering); diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql index ba62a974a4..6c1fd68c5b 100644 --- a/synapse/storage/schema/delta/28/public_roms_index.sql +++ b/synapse/storage/schema/delta/28/public_roms_index.sql @@ -13,4 +13,8 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX public_room_index on rooms(is_public); diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/schema/delta/28/receipts_user_id_index.sql index 452a1b3c6c..cb84c69baa 100644 --- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql +++ b/synapse/storage/schema/delta/28/receipts_user_id_index.sql @@ -13,6 +13,10 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql index 7e7b09820a..84b21cf813 100644 --- a/synapse/storage/schema/delta/29/push_actions.sql +++ b/synapse/storage/schema/delta/29/push_actions.sql @@ -26,6 +26,10 @@ UPDATE event_push_actions SET stream_ordering = ( UPDATE event_push_actions SET notif = 1, highlight = 0; +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/schema/delta/31/pushers_index.sql index 9027bccc69..a82add88fd 100644 --- a/synapse/storage/schema/delta/31/pushers_index.sql +++ b/synapse/storage/schema/delta/31/pushers_index.sql @@ -13,6 +13,10 @@ * limitations under the License. 
*/ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/schema/delta/35/contains_url.sql new file mode 100644 index 0000000000..6cd123027b --- /dev/null +++ b/synapse/storage/schema/delta/35/contains_url.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('event_contains_url_index', '{}'); diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/schema/delta/35/event_push_actions_index.sql index 4fc32c351a..2e836d8e9c 100644 --- a/synapse/storage/schema/delta/35/event_push_actions_index.sql +++ b/synapse/storage/schema/delta/35/event_push_actions_index.sql @@ -13,6 +13,5 @@ * limitations under the License. */ - CREATE INDEX event_push_actions_user_id_highlight_stream_ordering on event_push_actions( - user_id, highlight, stream_ordering - ); + INSERT into background_updates (update_name, progress_json) + VALUES ('epa_highlight_index', '{}'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0cff0a0cda..fdbdade536 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -306,13 +306,6 @@ class StateStore(SQLBaseStore): defer.returnValue(results) def _get_state_groups_from_groups_txn(self, txn, groups, types=None): - if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) - else: - where_clause = "" - results = {group: {} for group in groups} if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is @@ -342,20 +335,43 @@ class StateStore(SQLBaseStore): WHERE state_group IN ( SELECT state_group FROM state ) - %s; - """) % (where_clause,) - - for group in groups: - args = [group] - if types is not None: - args.extend([i for typ in types for i in typ]) + %s + """) - txn.execute(sql, args) - rows = self.cursor_to_dict(txn) - for row in rows: - key = (row["type"], row["state_key"]) - results[group][key] = row["event_id"] + # Turns out that postgres doesn't like doing a list of OR's and + # is about 1000x slower, so we just issue a query for each specific + # type separately. + if types: + clause_to_args = [ + ( + "AND type = ? AND state_key = ?", + (etype, state_key) + ) + for etype, state_key in types + ] + else: + # If types is None we fetch all the state, and so just use an + # empty where clause with no extra args.
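For concreteness, a worked example of the pairs built just below (the event types and state keys here are hypothetical): each pair becomes its own equality-only query, which PostgreSQL can answer straight from the index, instead of one large OR clause.

    types = [("m.room.name", ""), ("m.room.member", "@alice:example.com")]
    clause_to_args = [
        ("AND type = ? AND state_key = ?", (etype, state_key))
        for etype, state_key in types
    ]
    # -> two (where_clause, where_args) pairs, each executed once per group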
+ clause_to_args = [("", [])] + + for where_clause, where_args in clause_to_args: + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql % (where_clause,), args) + rows = self.cursor_to_dict(txn) + for row in rows: + key = (row["type"], row["state_key"]) + results[group][key] = row["event_id"] else: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: @@ -804,7 +820,7 @@ class StateStore(SQLBaseStore): def reindex_txn(txn): if isinstance(self.database_engine, PostgresEngine): txn.execute( - "CREATE INDEX state_groups_state_type_idx" + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" " ON state_groups_state(state_group, type, state_key)" ) txn.execute(
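For reference, the statement that the new where_clause support in background_updates.py composes for the partial index registered in events.py above; this is a worked example of the template on the PostgreSQL path (on SQLite both CONCURRENTLY and the WHERE clause are dropped, as the comments in the diff note):

    sql = (
        "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
        " %(where_clause)s"
    ) % {
        "conc": "CONCURRENTLY",
        "name": "event_contains_url_index",
        "table": "events",
        "columns": ", ".join(
            ["room_id", "topological_ordering", "stream_ordering"]
        ),
        "where_clause": "WHERE contains_url = true AND outlier = false",
    }
    # -> CREATE INDEX CONCURRENTLY event_contains_url_index ON events
    #    (room_id, topological_ordering, stream_ordering)
    #    WHERE contains_url = true AND outlier = false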