diff options
author | Erik Johnston <erik@matrix.org> | 2016-09-19 17:20:25 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2016-09-19 17:20:25 +0100 |
commit | 88acb99747c05f43f4563774f498993b27915493 (patch) | |
tree | 3184c8b912fdf11d29f7254ab879447f677e353b /synapse | |
parent | Merge branch 'release-v0.17.3' of github.com:matrix-org/synapse (diff) | |
parent | Bump version and changelog (diff) | |
download | synapse-88acb99747c05f43f4563774f498993b27915493.tar.xz |
Merge branch 'release-v0.18.0' of github.com:matrix-org/synapse v0.18.0
Diffstat (limited to 'synapse')
76 files changed, 3024 insertions, 777 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index b778cd65c9..41745170a1 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.17.3" +__version__ = "0.18.0" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index dcda40863f..98a50f0948 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -583,12 +583,15 @@ class Auth(object): """ # Can optionally look elsewhere in the request (e.g. headers) try: - user_id = yield self._get_appservice_user_id(request.args) + user_id = yield self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id defer.returnValue(synapse.types.create_requester(user_id)) - access_token = request.args["access_token"][0] + access_token = get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) + user_info = yield self.get_user_by_access_token(access_token, rights) user = user_info["user"] token_id = user_info["token_id"] @@ -629,17 +632,19 @@ class Auth(object): ) @defer.inlineCallbacks - def _get_appservice_user_id(self, request_args): + def _get_appservice_user_id(self, request): app_service = yield self.store.get_app_service_by_token( - request_args["access_token"][0] + get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) ) if app_service is None: defer.returnValue(None) - if "user_id" not in request_args: + if "user_id" not in request.args: defer.returnValue(app_service.sender) - user_id = request_args["user_id"][0] + user_id = request.args["user_id"][0] if app_service.sender == user_id: defer.returnValue(app_service.sender) @@ -833,7 +838,9 @@ class Auth(object): @defer.inlineCallbacks def get_appservice_by_req(self, request): try: - token = request.args["access_token"][0] + token = get_access_token_from_request( + request, self.TOKEN_NOT_FOUND_HTTP_STATUS + ) service = yield self.store.get_app_service_by_token(token) if not service: logger.warn("Unrecognised appservice access token: %s" % (token,)) @@ -1142,3 +1149,40 @@ class Auth(object): "This server requires you to be a moderator in the room to" " edit its room list entry" ) + + +def has_access_token(request): + """Checks if the request has an access_token. + + Returns: + bool: False if no access_token was given, True otherwise. + """ + query_params = request.args.get("access_token") + return bool(query_params) + + +def get_access_token_from_request(request, token_not_found_http_status=401): + """Extracts the access_token from the request. + + Args: + request: The http request. + token_not_found_http_status(int): The HTTP status code to set in the + AuthError if the token isn't found. This is used in some of the + legacy APIs to change the status code to 403 from the default of + 401 since some of the old clients depended on auth errors returning + 403. + Returns: + str: The access_token + Raises: + AuthError: If there isn't an access_token in the request. + """ + query_params = request.args.get("access_token") + # Try to get the access_token from the query params. + if not query_params: + raise AuthError( + token_not_found_http_status, + "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + + return query_params[0] diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 57587aed25..4e62a84b28 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -187,6 +187,7 @@ def start(config_options): def start(): ps.replicate() ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py new file mode 100644 index 0000000000..9fccc73db3 --- /dev/null +++ b/synapse/app/client_reader.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse + +from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.config.logger import setup_logging +from synapse.http.site import SynapseSite +from synapse.http.server import JsonResource +from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore +from synapse.replication.slave.storage.events import SlavedEventStore +from synapse.replication.slave.storage.keys import SlavedKeyStore +from synapse.replication.slave.storage.room import RoomStore +from synapse.replication.slave.storage.directory import DirectoryStore +from synapse.replication.slave.storage.registration import SlavedRegistrationStore +from synapse.rest.client.v1.room import PublicRoomListRestServlet +from synapse.server import HomeServer +from synapse.storage.client_ips import ClientIpStore +from synapse.storage.engines import create_engine +from synapse.util.async import sleep +from synapse.util.httpresourcetree import create_resource_tree +from synapse.util.logcontext import LoggingContext +from synapse.util.manhole import manhole +from synapse.util.rlimit import change_resource_limit +from synapse.util.versionstring import get_version_string +from synapse.crypto import context_factory + + +from twisted.internet import reactor, defer +from twisted.web.resource import Resource + +from daemonize import Daemonize + +import sys +import logging +import gc + +logger = logging.getLogger("synapse.app.client_reader") + + +class ClientReaderSlavedStore( + SlavedEventStore, + SlavedKeyStore, + RoomStore, + DirectoryStore, + SlavedApplicationServiceStore, + SlavedRegistrationStore, + BaseSlavedStore, + ClientIpStore, # After BaseSlavedStore because the constructor is different +): + pass + + +class ClientReaderServer(HomeServer): + def get_db_conn(self, run_new_connection=True): + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) + + if run_new_connection: + self.database_engine.on_new_connection(db_conn) + return db_conn + + def setup(self): + logger.info("Setting up.") + self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) + logger.info("Finished setting up.") + + def _listen_http(self, listener_config): + port = listener_config["port"] + bind_address = listener_config.get("bind_address", "") + site_tag = listener_config.get("tag", port) + resources = {} + for res in listener_config["resources"]: + for name in res["names"]: + if name == "metrics": + resources[METRICS_PREFIX] = MetricsResource(self) + elif name == "client": + resource = JsonResource(self, canonical_json=False) + PublicRoomListRestServlet(self).register(resource) + resources.update({ + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + }) + + root_resource = create_resource_tree(resources, Resource()) + reactor.listenTCP( + port, + SynapseSite( + "synapse.access.http.%s" % (site_tag,), + site_tag, + listener_config, + root_resource, + ), + interface=bind_address + ) + logger.info("Synapse client reader now listening on port %d", port) + + def start_listening(self, listeners): + for listener in listeners: + if listener["type"] == "http": + self._listen_http(listener) + elif listener["type"] == "manhole": + reactor.listenTCP( + listener["port"], + manhole( + username="matrix", + password="rabbithole", + globals={"hs": self}, + ), + interface=listener.get("bind_address", '127.0.0.1') + ) + else: + logger.warn("Unrecognized listener type: %s", listener["type"]) + + @defer.inlineCallbacks + def replicate(self): + http_client = self.get_simple_http_client() + store = self.get_datastore() + replication_url = self.config.worker_replication_url + + while True: + try: + args = store.stream_positions() + args["timeout"] = 30000 + result = yield http_client.get_json(replication_url, args=args) + yield store.process_replication(result) + except: + logger.exception("Error replicating from %r", replication_url) + yield sleep(5) + + +def start(config_options): + try: + config = HomeServerConfig.load_config( + "Synapse client reader", config_options + ) + except ConfigError as e: + sys.stderr.write("\n" + e.message + "\n") + sys.exit(1) + + assert config.worker_app == "synapse.app.client_reader" + + setup_logging(config.worker_log_config, config.worker_log_file) + + database_engine = create_engine(config.database_config) + + tls_server_context_factory = context_factory.ServerContextFactory(config) + + ss = ClientReaderServer( + config.server_name, + db_config=config.database_config, + tls_server_context_factory=tls_server_context_factory, + config=config, + version_string="Synapse/" + get_version_string(synapse), + database_engine=database_engine, + ) + + ss.setup() + ss.get_handlers() + ss.start_listening(config.worker_listeners) + + def run(): + with LoggingContext("run"): + logger.info("Running") + change_resource_limit(config.soft_file_limit) + if config.gc_thresholds: + gc.set_threshold(*config.gc_thresholds) + reactor.run() + + def start(): + ss.get_state_handler().start_caching() + ss.get_datastore().start_profiling() + ss.replicate() + + reactor.callWhenRunning(start) + + if config.worker_daemonize: + daemon = Daemonize( + app="synapse-client-reader", + pid=config.worker_pid_file, + action=run, + auto_close_fds=False, + verbose=True, + logger=logger, + ) + daemon.start() + else: + run() + + +if __name__ == '__main__': + with LoggingContext("main"): + start(sys.argv[1:]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7355499ae2..1f5ae1937e 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -182,6 +182,7 @@ def start(config_options): reactor.run() def start(): + ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() ss.replicate() diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 9d4c4a0750..6e5ec01c6c 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -188,6 +188,7 @@ def start(config_options): reactor.run() def start(): + ss.get_state_handler().start_caching() ss.get_datastore().start_profiling() ss.replicate() diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 8d755a4b33..d59f4a571c 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -276,6 +276,7 @@ def start(config_options): ps.replicate() ps.get_pusherpool().start() ps.get_datastore().start_profiling() + ps.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 07d3d047c6..46d390fd0f 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -242,6 +242,9 @@ class SynchrotronTyping(object): self._room_typing = {} def stream_positions(self): + # We must update this typing token from the response of the previous + # sync. In particular, the stream id may "reset" back to zero/a low + # value which we *must* use for the next replication request. return {"typing": self._latest_room_serial} def process_replication(self, result): @@ -462,6 +465,7 @@ def start(config_options): def start(): ss.get_datastore().start_profiling() ss.replicate() + ss.get_state_handler().start_caching() reactor.callWhenRunning(start) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index cc4af23962..b0eb0c6d9d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000 APP_SERVICE_PREFIX = "/_matrix/app/unstable" +def _is_valid_3pe_metadata(info): + if "instances" not in info: + return False + if not isinstance(info["instances"], list): + return False + return True + + def _is_valid_3pe_result(r, field): if not isinstance(r, dict): return False @@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.quote(protocol) ) try: - defer.returnValue((yield self.get_json(uri, {}))) + info = yield self.get_json(uri, {}) + + if not _is_valid_3pe_metadata(info): + logger.warning("query_3pe_protocol to %s did not return a" + " valid result", uri) + defer.returnValue(None) + + defer.returnValue(info) except Exception as ex: logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex) - defer.returnValue({}) + defer.returnValue(None) key = (service.id, protocol) return self.protocol_meta_cache.get(key) or ( diff --git a/synapse/config/server.py b/synapse/config/server.py index 51eaf423ce..ed5417d0c3 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -29,7 +29,6 @@ class ServerConfig(Config): self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") - self.secondary_directory_servers = config.get("secondary_directory_servers", []) if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': @@ -142,14 +141,6 @@ class ServerConfig(Config): # The GC threshold parameters to pass to `gc.set_threshold`, if defined # gc_thresholds: [700, 10, 10] - # A list of other Home Servers to fetch the public room directory from - # and include in the public room directory of this home server - # This is a temporary stopgap solution to populate new server with a - # list of rooms until there exists a good solution of a decentralized - # room directory. - # secondary_directory_servers: - # - matrix.org - # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e895b1c450..11605b34a3 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,9 +15,30 @@ class EventContext(object): + __slots__ = [ + "current_state_ids", + "prev_state_ids", + "state_group", + "rejected", + "push_actions", + "prev_group", + "delta_ids", + "prev_state_events", + ] + def __init__(self): + # The current state including the current event self.current_state_ids = None + # The current state excluding the current event self.prev_state_ids = None self.state_group = None + self.rejected = False self.push_actions = [] + + # A previously persisted state group and a delta between that + # and this state. + self.prev_group = None + self.delta_ids = None + + self.prev_state_events = None diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 627acc6a4f..06d0320b1a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -24,7 +24,6 @@ from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) from synapse.util import unwrapFirstError -from synapse.util.async import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @@ -122,8 +121,12 @@ class FederationClient(FederationBase): pdu.event_id ) + def send_presence(self, destination, states): + if destination != self.server_name: + self._transaction_queue.enqueue_presence(destination, states) + @log_function - def send_edu(self, destination, edu_type, content): + def send_edu(self, destination, edu_type, content, key=None): edu = Edu( origin=self.server_name, destination=destination, @@ -134,10 +137,16 @@ class FederationClient(FederationBase): sent_edus_counter.inc() # TODO, add errback, etc. - self._transaction_queue.enqueue_edu(edu) + self._transaction_queue.enqueue_edu(edu, key=key) return defer.succeed(None) @log_function + def send_device_messages(self, destination): + """Sends the device messages in the local database to the remote + destination""" + self._transaction_queue.enqueue_device_messages(destination) + + @log_function def send_failure(self, failure, destination): self._transaction_queue.enqueue_failure(failure, destination) return defer.succeed(None) @@ -166,7 +175,7 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content): + def query_client_keys(self, destination, content, timeout): """Query device keys for a device hosted on a remote server. Args: @@ -178,10 +187,12 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_device_keys") - return self.transport_layer.query_client_keys(destination, content) + return self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def claim_client_keys(self, destination, content): + def claim_client_keys(self, destination, content, timeout): """Claims one-time keys for a device hosted on a remote server. Args: @@ -193,7 +204,9 @@ class FederationClient(FederationBase): response """ sent_queries_counter.inc("client_one_time_keys") - return self.transport_layer.claim_client_keys(destination, content) + return self.transport_layer.claim_client_keys( + destination, content, timeout + ) @defer.inlineCallbacks @log_function @@ -471,7 +484,7 @@ class FederationClient(FederationBase): defer.DeferredList(deferreds, consumeErrors=True) ) for success, result in res: - if success: + if success and result: signed_events.append(result) batch.discard(result.event_id) @@ -705,24 +718,14 @@ class FederationClient(FederationBase): raise RuntimeError("Failed to send to any server.") - @defer.inlineCallbacks - def get_public_rooms(self, destinations): - results_by_server = {} - - @defer.inlineCallbacks - def _get_result(s): - if s == self.server_name: - defer.returnValue() - - try: - result = yield self.transport_layer.get_public_rooms(s) - results_by_server[s] = result - except: - logger.exception("Error getting room list from server %r", s) - - yield concurrently_execute(_get_result, destinations, 3) + def get_public_rooms(self, destination, limit=None, since_token=None, + search_filter=None): + if destination == self.server_name: + return - defer.returnValue(results_by_server) + return self.transport_layer.get_public_rooms( + destination, limit, since_token, search_filter + ) @defer.inlineCallbacks def query_auth(self, destination, room_id, event_id, local_auth): diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5621655098..3fa7b2315c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -188,7 +188,7 @@ class FederationServer(FederationBase): except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception as e: - logger.exception("Failed to handle edu %r", edu_type, e) + logger.exception("Failed to handle edu %r", edu_type) else: logger.warn("Received EDU of type %s with no handler", edu_type) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index cb2ef0210c..f8ca93e4c3 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -17,7 +17,7 @@ from twisted.internet import defer from .persistence import TransactionActions -from .units import Transaction +from .units import Transaction, Edu from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor @@ -26,6 +26,7 @@ from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) from synapse.util.metrics import measure_func +from synapse.handlers.presence import format_user_presence_state import synapse.metrics import logging @@ -69,18 +70,28 @@ class TransactionQueue(object): # destination -> list of tuple(edu, deferred) self.pending_edus_by_dest = edus = {} + # Presence needs to be separate as we send single aggragate EDUs + self.pending_presence_by_dest = presence = {} + self.pending_edus_keyed_by_dest = edus_keyed = {} + metrics.register_callback( "pending_pdus", lambda: sum(map(len, pdus.values())), ) metrics.register_callback( "pending_edus", - lambda: sum(map(len, edus.values())), + lambda: ( + sum(map(len, edus.values())) + + sum(map(len, presence.values())) + + sum(map(len, edus_keyed.values())) + ), ) # destination -> list of tuple(failure, deferred) self.pending_failures_by_dest = {} + self.last_device_stream_id_by_dest = {} + # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) @@ -128,13 +139,27 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_edu(self, edu): + def enqueue_presence(self, destination, states): + self.pending_presence_by_dest.setdefault(destination, {}).update({ + state.user_id: state for state in states + }) + + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) + + def enqueue_edu(self, edu, key=None): destination = edu.destination if not self.can_send_to(destination): return - self.pending_edus_by_dest.setdefault(destination, []).append(edu) + if key: + self.pending_edus_keyed_by_dest.setdefault( + destination, {} + )[(edu.edu_type, key)] = edu + else: + self.pending_edus_by_dest.setdefault(destination, []).append(edu) preserve_context_over_fn( self._attempt_new_transaction, destination @@ -155,179 +180,261 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) + def enqueue_device_messages(self, destination): + if destination == self.server_name or destination == "localhost": + return + + if not self.can_send_to(destination): + return + + preserve_context_over_fn( + self._attempt_new_transaction, destination + ) + @defer.inlineCallbacks def _attempt_new_transaction(self, destination): - yield run_on_reactor() - while True: - # list of (pending_pdu, deferred, order) - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. - # we need application-layer timeouts of some flavour of these - # requests - logger.debug( - "TX [%s] Transaction already in progress", - destination - ) - return + # list of (pending_pdu, deferred, order) + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. + # we need application-layer timeouts of some flavour of these + # requests + logger.debug( + "TX [%s] Transaction already in progress", + destination + ) + return + + try: + self.pending_transactions[destination] = 1 + + yield run_on_reactor() + + while True: + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) + + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, + ) + + device_message_edus, device_stream_id = ( + yield self._get_new_device_messages(destination) + ) + + pending_edus.extend(device_message_edus) + if pending_presence: + pending_edus.append( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.clock.time_msec() + ) + for presence in pending_presence.values() + ] + }, + ) + ) - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + self.last_device_stream_id_by_dest[destination] = ( + device_stream_id + ) + return - if not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - return + success = yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures, + device_stream_id, + should_delete_from_device_stream=bool(device_message_edus), + limiter=limiter, + ) + if not success: + break + except NotRetryingDestination: + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) - yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures + @defer.inlineCallbacks + def _get_new_device_messages(self, destination): + last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) + to_device_stream_id = self.store.get_to_device_stream_token() + contents, stream_id = yield self.store.get_new_device_msgs_for_remote( + destination, last_device_stream_id, to_device_stream_id + ) + edus = [ + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.direct_to_device", + content=content, ) + for content in contents + ] + defer.returnValue((edus, stream_id)) @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures): + pending_failures, device_stream_id, + should_delete_from_device_stream, limiter): - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[1]) - pdus = [x[0] for x in pending_pdus] - edus = pending_edus - failures = [x.get_dict() for x in pending_failures] + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[1]) + pdus = [x[0] for x in pending_pdus] + edus = pending_edus + failures = [x.get_dict() for x in pending_failures] - try: - self.pending_transactions[destination] = 1 + success = True - logger.debug("TX [%s] _attempt_new_transaction", destination) + try: + logger.debug("TX [%s] _attempt_new_transaction", destination) - txn_id = str(self._next_txn_id) + txn_id = str(self._next_txn_id) - limiter = yield get_retry_limiter( - destination, - self.clock, - self.store, - ) + logger.debug( + "TX [%s] {%s} Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, txn_id, + len(pdus), + len(edus), + len(failures) + ) - logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, txn_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) + logger.debug("TX [%s] Persisting transaction...", destination) - logger.debug("TX [%s] Persisting transaction...", destination) + transaction = Transaction.create_new( + origin_server_ts=int(self.clock.time_msec()), + transaction_id=txn_id, + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) - transaction = Transaction.create_new( - origin_server_ts=int(self.clock.time_msec()), - transaction_id=txn_id, - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) + self._next_txn_id += 1 - self._next_txn_id += 1 + yield self.transaction_actions.prepare_to_send(transaction) + + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] {%s} Sending transaction [%s]," + " (PDUs: %d, EDUs: %d, failures: %d)", + destination, txn_id, + transaction.transaction_id, + len(pdus), + len(edus), + len(failures), + ) - yield self.transaction_actions.prepare_to_send(transaction) + with limiter: + # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self.clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data + + try: + response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + code = 200 + + if response: + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: + logger.warn( + "Transaction returned error for %s: %s", + e_id, r, + ) + except HttpResponseException as e: + code = e.code + response = e.response - logger.debug("TX [%s] Persisted transaction", destination) logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d, failures: %d)", - destination, txn_id, - transaction.transaction_id, - len(pending_pdus), - len(pending_edus), - len(pending_failures), + "TX [%s] {%s} got %d response", + destination, txn_id, code ) - with limiter: - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self.clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - try: - response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - code = 200 - - if response: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warn( - "Transaction returned error for %s: %s", - e_id, r, - ) - except HttpResponseException as e: - code = e.code - response = e.response - - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) + yield self.transaction_actions.delivered( + transaction, code, response + ) - yield self.transaction_actions.delivered( - transaction, code, response - ) + logger.debug("TX [%s] Marked as delivered", destination) - logger.debug("TX [%s] Marked as delivered", destination) + if code != 200: + for p in pdus: + logger.info( + "Failed to send event %s to %s", p.event_id, destination + ) + success = False + else: + # Remove the acknowledged device messages from the database + if should_delete_from_device_stream: + yield self.store.delete_device_msgs_for_remote( + destination, device_stream_id + ) + self.last_device_stream_id_by_dest[destination] = device_stream_id + except RuntimeError as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) - if code != 200: - for p in pdus: - logger.info( - "Failed to send event %s to %s", p.event_id, destination - ) - except NotRetryingDestination: - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - except RuntimeError as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) + success = False + + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) + except Exception as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) - for p in pdus: - logger.info("Failed to send event %s to %s", p.event_id, destination) - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) + success = False - for p in pdus: - logger.info("Failed to send event %s to %s", p.event_id, destination) + for p in pdus: + logger.info("Failed to send event %s to %s", p.event_id, destination) - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) + defer.returnValue(success) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 3d088e43cb..db45c7826c 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -248,12 +248,22 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def get_public_rooms(self, remote_server): + def get_public_rooms(self, remote_server, limit, since_token, + search_filter=None): path = PREFIX + "/publicRooms" + args = {} + if limit: + args["limit"] = [str(limit)] + if since_token: + args["since"] = [since_token] + + # TODO(erikj): Actually send the search_filter across federation. + response = yield self.client.get_json( destination=remote_server, path=path, + args=args, ) defer.returnValue(response) @@ -298,7 +308,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content): + def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -327,12 +337,13 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content): + def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -363,6 +374,7 @@ class TransportLayerClient(object): destination=destination, path=path, data=query_content, + timeout=timeout, ) defer.returnValue(content) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 37c0d4fbc4..fec337be64 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,7 +18,9 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import ( + parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, +) from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string @@ -554,7 +556,11 @@ class PublicRoomList(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query): - data = yield self.room_list_handler.get_local_public_room_list() + limit = parse_integer_from_args(query, "limit", 0) + since_token = parse_string_from_args(query, "since", None) + data = yield self.room_list_handler.get_local_public_room_list( + limit, since_token + ) defer.returnValue((200, data)) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index b440280b74..88fa0bb2e4 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -176,12 +176,41 @@ class ApplicationServicesHandler(object): defer.returnValue(ret) @defer.inlineCallbacks - def get_3pe_protocols(self): + def get_3pe_protocols(self, only_protocol=None): services = yield self.store.get_app_services() protocols = {} + + # Collect up all the individual protocol responses out of the ASes for s in services: for p in s.protocols: - protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) + if only_protocol is not None and p != only_protocol: + continue + + if p not in protocols: + protocols[p] = [] + + info = yield self.appservice_api.get_3pe_protocol(s, p) + + if info is not None: + protocols[p].append(info) + + def _merge_instances(infos): + if not infos: + return {} + + # Merge the 'instances' lists of multiple results, but just take + # the other fields from the first as they ought to be identical + # copy the result so as not to corrupt the cached one + combined = dict(infos[0]) + combined["instances"] = list(combined["instances"]) + + for info in infos[1:]: + combined["instances"].extend(info["instances"]) + + return combined + + for p in protocols.keys(): + protocols[p] = _merge_instances(protocols[p]) defer.returnValue(protocols) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 8d630c6b1a..aa68755936 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,7 +58,7 @@ class DeviceHandler(BaseHandler): attempts = 0 while attempts < 5: try: - device_id = stringutils.random_string_with_symbols(16) + device_id = stringutils.random_string(10).upper() yield self.store.store_device( user_id=user_id, device_id=device_id, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py new file mode 100644 index 0000000000..c5368e5df2 --- /dev/null +++ b/synapse/handlers/devicemessage.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.types import get_domain_from_id +from synapse.util.stringutils import random_string + + +logger = logging.getLogger(__name__) + + +class DeviceMessageHandler(object): + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.federation = hs.get_replication_layer() + + self.federation.register_edu_handler( + "m.direct_to_device", self.on_direct_to_device_edu + ) + + @defer.inlineCallbacks + def on_direct_to_device_edu(self, origin, content): + local_messages = {} + sender_user_id = content["sender"] + if origin != get_domain_from_id(sender_user_id): + logger.warn( + "Dropping device message from %r with spoofed sender %r", + origin, sender_user_id + ) + message_type = content["type"] + message_id = content["message_id"] + for user_id, by_device in content["messages"].items(): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + + stream_id = yield self.store.add_messages_from_remote_to_device_inbox( + origin, message_id, local_messages + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + @defer.inlineCallbacks + def send_device_message(self, sender_user_id, message_type, messages): + + local_messages = {} + remote_messages = {} + for user_id, by_device in messages.items(): + if self.is_mine_id(user_id): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + else: + destination = get_domain_from_id(user_id) + remote_messages.setdefault(destination, {})[user_id] = by_device + + message_id = random_string(16) + + remote_edu_contents = {} + for destination, messages in remote_messages.items(): + remote_edu_contents[destination] = { + "messages": messages, + "sender": sender_user_id, + "type": message_type, + "message_id": message_id, + } + + stream_id = yield self.store.add_messages_to_device_inbox( + local_messages, remote_edu_contents + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + for destination in remote_messages.keys(): + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + self.federation.send_device_messages(destination) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 2c7bfd91ed..fd11935b40 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,14 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections -import json +import ujson as json import logging +from canonicaljson import encode_canonical_json from twisted.internet import defer -from synapse.api import errors -import synapse.types +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination logger = logging.getLogger(__name__) @@ -29,8 +31,9 @@ class E2eKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() self.federation = hs.get_replication_layer() + self.device_handler = hs.get_device_handler() self.is_mine_id = hs.is_mine_id - self.server_name = hs.hostname + self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the @@ -40,7 +43,7 @@ class E2eKeysHandler(object): ) @defer.inlineCallbacks - def query_devices(self, query_body): + def query_devices(self, query_body, timeout): """ Handle a device key query from a client { @@ -63,27 +66,60 @@ class E2eKeysHandler(object): # separate users by domain. # make a map from domain to user_id to device_ids - queries_by_domain = collections.defaultdict(dict) + local_query = {} + remote_queries = {} + for user_id, device_ids in device_keys_query.items(): - user = synapse.types.UserID.from_string(user_id) - queries_by_domain[user.domain][user_id] = device_ids + if self.is_mine_id(user_id): + local_query[user_id] = device_ids + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_ids # do the queries - # TODO: do these in parallel + failures = {} results = {} - for destination, destination_query in queries_by_domain.items(): - if destination == self.server_name: - res = yield self.query_local_devices(destination_query) - else: - res = yield self.federation.query_client_keys( - destination, {"device_keys": destination_query} - ) - res = res["device_keys"] - for user_id, keys in res.items(): - if user_id in destination_query: + if local_query: + local_result = yield self.query_local_devices(local_query) + for user_id, keys in local_result.items(): + if user_id in local_query: results[user_id] = keys - defer.returnValue((200, {"device_keys": results})) + @defer.inlineCallbacks + def do_remote_query(destination): + destination_query = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout + ) + + for user_id, keys in remote_result["device_keys"].items(): + if user_id in destination_query: + results[user_id] = keys + + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(do_remote_query)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "device_keys": results, "failures": failures, + }) @defer.inlineCallbacks def query_local_devices(self, query): @@ -104,7 +140,7 @@ class E2eKeysHandler(object): if not self.is_mine_id(user_id): logger.warning("Request for keys for non-local user %s", user_id) - raise errors.SynapseError(400, "Not a user here") + raise SynapseError(400, "Not a user here") if not device_ids: local_query.append((user_id, None)) @@ -137,3 +173,107 @@ class E2eKeysHandler(object): device_keys_query = query_body.get("device_keys", {}) res = yield self.query_local_devices(device_keys_query) defer.returnValue({"device_keys": res}) + + @defer.inlineCallbacks + def claim_one_time_keys(self, query, timeout): + local_query = [] + remote_queries = {} + + for user_id, device_keys in query.get("one_time_keys", {}).items(): + if self.is_mine_id(user_id): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_keys + + results = yield self.store.claim_e2e_one_time_keys(local_query) + + json_result = {} + failures = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + + @defer.inlineCallbacks + def claim_client_keys(destination): + device_keys = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(claim_client_keys)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "one_time_keys": json_result, + "failures": failures + }) + + @defer.inlineCallbacks + def upload_keys_for_user(self, user_id, device_id, keys): + time_now = self.clock.time_msec() + + # TODO: Validate the JSON to make sure it has the right keys. + device_keys = keys.get("device_keys", None) + if device_keys: + logger.info( + "Updating device_keys for device %r for user %s at %d", + device_id, user_id, time_now + ) + # TODO: Sign the JSON with the server key + yield self.store.set_e2e_device_keys( + user_id, device_id, time_now, + encode_canonical_json(device_keys) + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + logger.info( + "Adding %d one_time_keys for device %r for user %r at %d", + len(one_time_keys), device_id, user_id, time_now + ) + key_list = [] + for key_id, key_json in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, encode_canonical_json(key_json) + )) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, key_list + ) + + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered(user_id, device_id) + + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + + defer.returnValue({"one_time_key_counts": result}) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dc90a5dde4..f7cb3c1bb2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -832,11 +832,13 @@ class FederationHandler(BaseHandler): new_pdu = event - message_handler = self.hs.get_handlers().message_handler - destinations = yield message_handler.get_joined_hosts_for_room_from_state( - context + users_in_room = yield self.store.get_joined_users_from_context(event, context) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + if not self.hs.is_mine_id(user_id) ) - destinations = set(destinations) + destinations.discard(origin) logger.debug( @@ -1055,11 +1057,12 @@ class FederationHandler(BaseHandler): new_pdu = event - message_handler = self.hs.get_handlers().message_handler - destinations = yield message_handler.get_joined_hosts_for_room_from_state( - context + users_in_room = yield self.store.get_joined_users_from_context(event, context) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + if not self.hs.is_mine_id(user_id) ) - destinations = set(destinations) destinations.discard(origin) logger.debug( @@ -1582,10 +1585,12 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state + context.current_state_ids = dict(context.current_state_ids) context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() if k != event_key }) + context.prev_state_ids = dict(context.prev_state_ids) context.prev_state_ids.update({ k: a.event_id for k, a in auth_events.items() }) @@ -1667,10 +1672,12 @@ class FederationHandler(BaseHandler): # 4. Look at rejects and their proofs. # TODO. + context.current_state_ids = dict(context.current_state_ids) context.current_state_ids.update({ k: a.event_id for k, a in auth_events.items() if k != event_key }) + context.prev_state_ids = dict(context.prev_state_ids) context.prev_state_ids.update({ k: a.event_id for k, a in auth_events.items() }) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3577db0595..178209a209 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -30,7 +30,6 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.metrics import measure_func -from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -945,7 +944,12 @@ class MessageHandler(BaseHandler): event_stream_id, max_stream_id ) - destinations = yield self.get_joined_hosts_for_room_from_state(context) + users_in_room = yield self.store.get_joined_users_from_context(event, context) + + destinations = [ + get_domain_from_id(user_id) for user_id in users_in_room + if not self.hs.is_mine_id(user_id) + ] @defer.inlineCallbacks def _notify(): @@ -963,39 +967,3 @@ class MessageHandler(BaseHandler): preserve_fn(federation_handler.handle_new_event)( event, destinations=destinations, ) - - def get_joined_hosts_for_room_from_state(self, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_hosts_for_room_from_state( - state_group, context.current_state_ids - ) - - @cachedInlineCallbacks(num_args=1, cache_context=True) - def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids, - cache_context): - - # Don't bother getting state for people on the same HS - current_state = yield self.store.get_events([ - e_id for key, e_id in current_state_ids.items() - if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1]) - ]) - - destinations = set() - for e in current_state.itervalues(): - try: - if e.type == EventTypes.Member: - if e.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(e.state_key)) - except SynapseError: - logger.warn( - "Failed to get destination from event %s", e.event_id - ) - - defer.returnValue(destinations) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index cf82a2336e..b047ae2250 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -52,6 +52,11 @@ bump_active_time_counter = metrics.register_counter("bump_active_time") get_updates_counter = metrics.register_counter("get_updates", labels=["type"]) +notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"]) +state_transition_counter = metrics.register_counter( + "state_transition", labels=["from", "to"] +) + # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" @@ -212,7 +217,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ logger.info( - "Performing _on_shutdown. Persiting %d unpersisted changes", + "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) ) @@ -229,7 +234,7 @@ class PresenceHandler(object): may stack up and slow down shutdown times. """ logger.info( - "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes", + "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes", len(self.unpersisted_users_changes) ) @@ -260,6 +265,12 @@ class PresenceHandler(object): to_notify = {} # Changes we want to notify everyone about to_federation_ping = {} # These need sending keep-alives + # Only bother handling the last presence change for each user + new_states_dict = {} + for new_state in new_states: + new_states_dict[new_state.user_id] = new_state + new_state = new_states_dict.values() + for new_state in new_states: user_id = new_state.user_id @@ -614,18 +625,8 @@ class PresenceHandler(object): Args: hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` """ - now = self.clock.time_msec() for host, states in hosts_to_states.items(): - self.federation.send_edu( - destination=host, - edu_type="m.presence", - content={ - "push": [ - _format_user_presence_state(state, now) - for state in states - ] - } - ) + self.federation.send_presence(host, states) @defer.inlineCallbacks def incoming_presence(self, origin, content): @@ -646,6 +647,13 @@ class PresenceHandler(object): ) continue + if get_domain_from_id(user_id) != origin: + logger.info( + "Got presence update from %r with bad 'user_id': %r", + origin, user_id, + ) + continue + presence_state = push.get("presence", None) if not presence_state: logger.info( @@ -705,13 +713,13 @@ class PresenceHandler(object): defer.returnValue([ { "type": "m.presence", - "content": _format_user_presence_state(state, now), + "content": format_user_presence_state(state, now), } for state in updates ]) else: defer.returnValue([ - _format_user_presence_state(state, now) for state in updates + format_user_presence_state(state, now) for state in updates ]) @defer.inlineCallbacks @@ -939,33 +947,38 @@ class PresenceHandler(object): def should_notify(old_state, new_state): """Decides if a presence state change should be sent to interested parties. """ + if old_state == new_state: + return False + if old_state.status_msg != new_state.status_msg: + notify_reason_counter.inc("status_msg_change") return True - if old_state.state == PresenceState.ONLINE: - if new_state.state != PresenceState.ONLINE: - # Always notify for online -> anything - return True + if old_state.state != new_state.state: + notify_reason_counter.inc("state_change") + state_transition_counter.inc(old_state.state, new_state.state) + return True + if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: + notify_reason_counter.inc("current_active_change") return True if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Only notify about last active bumps if we're not currently acive - if not (old_state.currently_active and new_state.currently_active): + if not new_state.currently_active: + notify_reason_counter.inc("last_active_change_online") return True elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. - return True - - if old_state.state != new_state.state: + notify_reason_counter.inc("last_active_change_not_online") return True return False -def _format_user_presence_state(state, now): +def format_user_presence_state(state, now): """Convert UserPresenceState to a format that can be sent down to clients and to other servers. """ @@ -1078,7 +1091,7 @@ class PresenceEventSource(object): defer.returnValue(([ { "type": "m.presence", - "content": _format_user_presence_state(s, now), + "content": format_user_presence_state(s, now), } for s in updates.values() if include_offline or s.state != PresenceState.OFFLINE diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 726f7308d2..e536a909d0 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler): } }, }, + key=(room_id, receipt_type, user_id), ) @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bf6b1c1535..cbd26f8f95 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,12 +20,10 @@ from ._base import BaseHandler from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, JoinRules, RoomCreationPreset, Membership, + EventTypes, JoinRules, RoomCreationPreset ) from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.util import stringutils -from synapse.util.async import concurrently_execute -from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from collections import OrderedDict @@ -36,8 +34,6 @@ import string logger = logging.getLogger(__name__) -REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 - id_server_scheme = "https://" @@ -196,6 +192,11 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + for invitee in invite_list: yield room_member_handler.update_membership( requester, @@ -203,6 +204,7 @@ class RoomCreationHandler(BaseHandler): room_id, "invite", ratelimit=False, + content=content, ) for invite_3pid in invite_3pid_list: @@ -342,149 +344,6 @@ class RoomCreationHandler(BaseHandler): ) -class RoomListHandler(BaseHandler): - def __init__(self, hs): - super(RoomListHandler, self).__init__(hs) - self.response_cache = ResponseCache(hs) - self.remote_list_request_cache = ResponseCache(hs) - self.remote_list_cache = {} - self.fetch_looping_call = hs.get_clock().looping_call( - self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL - ) - self.fetch_all_remote_lists() - - def get_local_public_room_list(self): - result = self.response_cache.get(()) - if not result: - result = self.response_cache.set((), self._get_public_room_list()) - return result - - @defer.inlineCallbacks - def _get_public_room_list(self): - room_ids = yield self.store.get_public_room_ids() - - results = [] - - @defer.inlineCallbacks - def handle_room(room_id): - current_state = yield self.state_handler.get_current_state(room_id) - - # Double check that this is actually a public room. - join_rules_event = current_state.get((EventTypes.JoinRules, "")) - if join_rules_event: - join_rule = join_rules_event.content.get("join_rule", None) - if join_rule and join_rule != JoinRules.PUBLIC: - defer.returnValue(None) - - result = {"room_id": room_id} - - num_joined_users = len([ - 1 for _, event in current_state.items() - if event.type == EventTypes.Member - and event.membership == Membership.JOIN - ]) - if num_joined_users == 0: - return - - result["num_joined_members"] = num_joined_users - - aliases = yield self.store.get_aliases_for_room(room_id) - if aliases: - result["aliases"] = aliases - - name_event = yield current_state.get((EventTypes.Name, "")) - if name_event: - name = name_event.content.get("name", None) - if name: - result["name"] = name - - topic_event = current_state.get((EventTypes.Topic, "")) - if topic_event: - topic = topic_event.content.get("topic", None) - if topic: - result["topic"] = topic - - canonical_event = current_state.get((EventTypes.CanonicalAlias, "")) - if canonical_event: - canonical_alias = canonical_event.content.get("alias", None) - if canonical_alias: - result["canonical_alias"] = canonical_alias - - visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, "")) - visibility = None - if visibility_event: - visibility = visibility_event.content.get("history_visibility", None) - result["world_readable"] = visibility == "world_readable" - - guest_event = current_state.get((EventTypes.GuestAccess, "")) - guest = None - if guest_event: - guest = guest_event.content.get("guest_access", None) - result["guest_can_join"] = guest == "can_join" - - avatar_event = current_state.get(("m.room.avatar", "")) - if avatar_event: - avatar_url = avatar_event.content.get("url", None) - if avatar_url: - result["avatar_url"] = avatar_url - - results.append(result) - - yield concurrently_execute(handle_room, room_ids, 10) - - # FIXME (erikj): START is no longer a valid value - defer.returnValue({"start": "START", "end": "END", "chunk": results}) - - @defer.inlineCallbacks - def fetch_all_remote_lists(self): - deferred = self.hs.get_replication_layer().get_public_rooms( - self.hs.config.secondary_directory_servers - ) - self.remote_list_request_cache.set((), deferred) - self.remote_list_cache = yield deferred - - @defer.inlineCallbacks - def get_aggregated_public_room_list(self): - """ - Get the public room list from this server and the servers - specified in the secondary_directory_servers config option. - XXX: Pagination... - """ - # We return the results from out cache which is updated by a looping call, - # unless we're missing a cache entry, in which case wait for the result - # of the fetch if there's one in progress. If not, omit that server. - wait = False - for s in self.hs.config.secondary_directory_servers: - if s not in self.remote_list_cache: - logger.warn("No cached room list from %s: waiting for fetch", s) - wait = True - break - - if wait and self.remote_list_request_cache.get(()): - yield self.remote_list_request_cache.get(()) - - public_rooms = yield self.get_local_public_room_list() - - # keep track of which room IDs we've seen so we can de-dup - room_ids = set() - - # tag all the ones in our list with our server name. - # Also add the them to the de-deping set - for room in public_rooms['chunk']: - room["server_name"] = self.hs.hostname - room_ids.add(room["room_id"]) - - # Now add the results from federation - for server_name, server_result in self.remote_list_cache.items(): - for room in server_result["chunk"]: - if room["room_id"] not in room_ids: - room["server_name"] = server_name - public_rooms["chunk"].append(room) - room_ids.add(room["room_id"]) - - defer.returnValue(public_rooms) - - class RoomContextHandler(BaseHandler): @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, is_guest): diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py new file mode 100644 index 0000000000..5a533682c5 --- /dev/null +++ b/synapse/handlers/room_list.py @@ -0,0 +1,400 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import BaseHandler + +from synapse.api.constants import ( + EventTypes, JoinRules, +) +from synapse.util.async import concurrently_execute +from synapse.util.caches.response_cache import ResponseCache + +from collections import namedtuple +from unpaddedbase64 import encode_base64, decode_base64 + +import logging +import msgpack + +logger = logging.getLogger(__name__) + +REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 + + +class RoomListHandler(BaseHandler): + def __init__(self, hs): + super(RoomListHandler, self).__init__(hs) + self.response_cache = ResponseCache(hs) + self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) + + def get_local_public_room_list(self, limit=None, since_token=None, + search_filter=None): + if search_filter: + # We explicitly don't bother caching searches. + return self._get_public_room_list(limit, since_token, search_filter) + + result = self.response_cache.get((limit, since_token)) + if not result: + result = self.response_cache.set( + (limit, since_token), + self._get_public_room_list(limit, since_token) + ) + return result + + @defer.inlineCallbacks + def _get_public_room_list(self, limit=None, since_token=None, + search_filter=None): + if since_token and since_token != "END": + since_token = RoomListNextBatch.from_token(since_token) + else: + since_token = None + + rooms_to_order_value = {} + rooms_to_num_joined = {} + rooms_to_latest_event_ids = {} + + newly_visible = [] + newly_unpublished = [] + if since_token: + stream_token = since_token.stream_ordering + current_public_id = yield self.store.get_current_public_room_stream_id() + public_room_stream_id = since_token.public_room_stream_id + newly_visible, newly_unpublished = yield self.store.get_public_room_changes( + public_room_stream_id, current_public_id + ) + else: + stream_token = yield self.store.get_room_max_stream_ordering() + public_room_stream_id = yield self.store.get_current_public_room_stream_id() + + room_ids = yield self.store.get_public_room_ids_at_stream_id( + public_room_stream_id + ) + + # We want to return rooms in a particular order: the number of joined + # users. We then arbitrarily use the room_id as a tie breaker. + + @defer.inlineCallbacks + def get_order_for_room(room_id): + latest_event_ids = rooms_to_latest_event_ids.get(room_id, None) + if not latest_event_ids: + latest_event_ids = yield self.store.get_forward_extremeties_for_room( + room_id, stream_token + ) + rooms_to_latest_event_ids[room_id] = latest_event_ids + + if not latest_event_ids: + return + + joined_users = yield self.state_handler.get_current_user_in_room( + room_id, latest_event_ids, + ) + num_joined_users = len(joined_users) + rooms_to_num_joined[room_id] = num_joined_users + + if num_joined_users == 0: + return + + # We want larger rooms to be first, hence negating num_joined_users + rooms_to_order_value[room_id] = (-num_joined_users, room_id) + + yield concurrently_execute(get_order_for_room, room_ids, 10) + + sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) + sorted_rooms = [room_id for room_id, _ in sorted_entries] + + # `sorted_rooms` should now be a list of all public room ids that is + # stable across pagination. Therefore, we can use indices into this + # list as our pagination tokens. + + # Filter out rooms that we don't want to return + rooms_to_scan = [ + r for r in sorted_rooms + if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 + ] + + if since_token: + # Filter out rooms we've already returned previously + # `since_token.current_limit` is the index of the last room we + # sent down, so we exclude it and everything before/after it. + if since_token.direction_is_forward: + rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:] + else: + rooms_to_scan = rooms_to_scan[:since_token.current_limit] + rooms_to_scan.reverse() + + # Actually generate the entries. _generate_room_entry will append to + # chunk but will stop if len(chunk) > limit + chunk = [] + if limit and not search_filter: + step = limit + 1 + for i in xrange(0, len(rooms_to_scan), step): + # We iterate here because the vast majority of cases we'll stop + # at first iteration, but occaisonally _generate_room_entry + # won't append to the chunk and so we need to loop again. + # We don't want to scan over the entire range either as that + # would potentially waste a lot of work. + yield concurrently_execute( + lambda r: self._generate_room_entry( + r, rooms_to_num_joined[r], + chunk, limit, search_filter + ), + rooms_to_scan[i:i + step], 10 + ) + if len(chunk) >= limit + 1: + break + else: + yield concurrently_execute( + lambda r: self._generate_room_entry( + r, rooms_to_num_joined[r], + chunk, limit, search_filter + ), + rooms_to_scan, 5 + ) + + chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) + + # Work out the new limit of the batch for pagination, or None if we + # know there are no more results that would be returned. + # i.e., [since_token.current_limit..new_limit] is the batch of rooms + # we've returned (or the reverse if we paginated backwards) + # We tried to pull out limit + 1 rooms above, so if we have <= limit + # then we know there are no more results to return + new_limit = None + if chunk and (not limit or len(chunk) > limit): + + if not since_token or since_token.direction_is_forward: + if limit: + chunk = chunk[:limit] + last_room_id = chunk[-1]["room_id"] + else: + if limit: + chunk = chunk[-limit:] + last_room_id = chunk[0]["room_id"] + + new_limit = sorted_rooms.index(last_room_id) + + results = { + "chunk": chunk, + } + + if since_token: + results["new_rooms"] = bool(newly_visible) + + if not since_token or since_token.direction_is_forward: + if new_limit is not None: + results["next_batch"] = RoomListNextBatch( + stream_ordering=stream_token, + public_room_stream_id=public_room_stream_id, + current_limit=new_limit, + direction_is_forward=True, + ).to_token() + + if since_token: + results["prev_batch"] = since_token.copy_and_replace( + direction_is_forward=False, + current_limit=since_token.current_limit + 1, + ).to_token() + else: + if new_limit is not None: + results["prev_batch"] = RoomListNextBatch( + stream_ordering=stream_token, + public_room_stream_id=public_room_stream_id, + current_limit=new_limit, + direction_is_forward=False, + ).to_token() + + if since_token: + results["next_batch"] = since_token.copy_and_replace( + direction_is_forward=True, + current_limit=since_token.current_limit - 1, + ).to_token() + + defer.returnValue(results) + + @defer.inlineCallbacks + def _generate_room_entry(self, room_id, num_joined_users, chunk, limit, + search_filter): + if limit and len(chunk) > limit + 1: + # We've already got enough, so lets just drop it. + return + + result = { + "room_id": room_id, + "num_joined_members": num_joined_users, + } + + current_state_ids = yield self.state_handler.get_current_state_ids(room_id) + + event_map = yield self.store.get_events([ + event_id for key, event_id in current_state_ids.items() + if key[0] in ( + EventTypes.JoinRules, + EventTypes.Name, + EventTypes.Topic, + EventTypes.CanonicalAlias, + EventTypes.RoomHistoryVisibility, + EventTypes.GuestAccess, + "m.room.avatar", + ) + ]) + + current_state = { + (ev.type, ev.state_key): ev + for ev in event_map.values() + } + + # Double check that this is actually a public room. + join_rules_event = current_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + if join_rule and join_rule != JoinRules.PUBLIC: + defer.returnValue(None) + + aliases = yield self.store.get_aliases_for_room(room_id) + if aliases: + result["aliases"] = aliases + + name_event = yield current_state.get((EventTypes.Name, "")) + if name_event: + name = name_event.content.get("name", None) + if name: + result["name"] = name + + topic_event = current_state.get((EventTypes.Topic, "")) + if topic_event: + topic = topic_event.content.get("topic", None) + if topic: + result["topic"] = topic + + canonical_event = current_state.get((EventTypes.CanonicalAlias, "")) + if canonical_event: + canonical_alias = canonical_event.content.get("alias", None) + if canonical_alias: + result["canonical_alias"] = canonical_alias + + visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, "")) + visibility = None + if visibility_event: + visibility = visibility_event.content.get("history_visibility", None) + result["world_readable"] = visibility == "world_readable" + + guest_event = current_state.get((EventTypes.GuestAccess, "")) + guest = None + if guest_event: + guest = guest_event.content.get("guest_access", None) + result["guest_can_join"] = guest == "can_join" + + avatar_event = current_state.get(("m.room.avatar", "")) + if avatar_event: + avatar_url = avatar_event.content.get("url", None) + if avatar_url: + result["avatar_url"] = avatar_url + + if _matches_room_entry(result, search_filter): + chunk.append(result) + + @defer.inlineCallbacks + def get_remote_public_room_list(self, server_name, limit=None, since_token=None, + search_filter=None): + if search_filter: + # We currently don't support searching across federation, so we have + # to do it manually without pagination + limit = None + since_token = None + + res = yield self._get_remote_list_cached( + server_name, limit=limit, since_token=since_token, + ) + + if search_filter: + res = {"chunk": [ + entry + for entry in list(res.get("chunk", [])) + if _matches_room_entry(entry, search_filter) + ]} + + defer.returnValue(res) + + def _get_remote_list_cached(self, server_name, limit=None, since_token=None, + search_filter=None): + repl_layer = self.hs.get_replication_layer() + if search_filter: + # We can't cache when asking for search + return repl_layer.get_public_rooms( + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + ) + + result = self.remote_response_cache.get((server_name, limit, since_token)) + if not result: + result = self.remote_response_cache.set( + (server_name, limit, since_token), + repl_layer.get_public_rooms( + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + ) + ) + return result + + +class RoomListNextBatch(namedtuple("RoomListNextBatch", ( + "stream_ordering", # stream_ordering of the first public room list + "public_room_stream_id", # public room stream id for first public room list + "current_limit", # The number of previous rooms returned + "direction_is_forward", # Bool if this is a next_batch, false if prev_batch +))): + + KEY_DICT = { + "stream_ordering": "s", + "public_room_stream_id": "p", + "current_limit": "n", + "direction_is_forward": "d", + } + + REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} + + @classmethod + def from_token(cls, token): + return RoomListNextBatch(**{ + cls.REVERSE_KEY_DICT[key]: val + for key, val in msgpack.loads(decode_base64(token)).items() + }) + + def to_token(self): + return encode_base64(msgpack.dumps({ + self.KEY_DICT[key]: val + for key, val in self._asdict().items() + })) + + def copy_and_replace(self, **kwds): + return self._replace( + **kwds + ) + + +def _matches_room_entry(room_entry, search_filter): + if search_filter and search_filter.get("generic_search_term", None): + generic_search_term = search_filter["generic_search_term"].upper() + if generic_search_term in room_entry.get("name", "").upper(): + return True + elif generic_search_term in room_entry.get("topic", "").upper(): + return True + elif generic_search_term in room_entry.get("canonical_alias", "").upper(): + return True + else: + return True + + return False diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0b530b9034..0548b81c34 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -187,6 +187,7 @@ class TypingHandler(object): "user_id": user_id, "typing": typing, }, + key=(room_id, user_id), )) yield preserve_context_over_deferred( @@ -199,7 +200,14 @@ class TypingHandler(object): user_id = content["user_id"] # Check that the string is a valid user id - UserID.from_string(user_id) + user = UserID.from_string(user_id) + + if user.domain != origin: + logger.info( + "Got typing update from %r with bad 'user_id': %r", + origin, user_id, + ) + return users = yield self.state.get_current_user_in_room(room_id) domains = set(get_domain_from_id(u) for u in users) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index f93093dd85..d0556ae347 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -246,7 +246,7 @@ class MatrixFederationHttpClient(object): @defer.inlineCallbacks def put_json(self, destination, path, data={}, json_data_callback=None, - long_retries=False): + long_retries=False, timeout=None): """ Sends the specifed json data using PUT Args: @@ -259,6 +259,8 @@ class MatrixFederationHttpClient(object): use as the request body. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -285,6 +287,7 @@ class MatrixFederationHttpClient(object): body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=long_retries, + timeout=timeout, ) if 200 <= response.code < 300: @@ -300,7 +303,8 @@ class MatrixFederationHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=True): + def post_json(self, destination, path, data={}, long_retries=True, + timeout=None): """ Sends the specifed json data using POST Args: @@ -311,6 +315,8 @@ class MatrixFederationHttpClient(object): the request body. This will be encoded as JSON. long_retries (bool): A boolean that indicates whether we should retry for a short or long time. + timeout(int): How long to try (in ms) the destination for before + giving up. None indicates no timeout. Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result @@ -331,6 +337,7 @@ class MatrixFederationHttpClient(object): body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, long_retries=True, + timeout=timeout, ) if 200 <= response.code < 300: diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index e41afeab8e..9346386238 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -41,9 +41,13 @@ def parse_integer(request, name, default=None, required=False): SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ - if name in request.args: + return parse_integer_from_args(request.args, name, default, required) + + +def parse_integer_from_args(args, name, default=None, required=False): + if name in args: try: - return int(request.args[name][0]) + return int(args[name][0]) except: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message) @@ -116,9 +120,15 @@ def parse_string(request, name, default=None, required=False, parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ + return parse_string_from_args( + request.args, name, default, required, allowed_values, param_type, + ) - if name in request.args: - value = request.args[name][0] + +def parse_string_from_args(args, name, default=None, required=False, + allowed_values=None, param_type="string"): + if name in args: + value = args[name][0] if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( name, ", ".join(repr(v) for v in allowed_values) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index edb00ed206..85effdfa46 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -263,6 +263,8 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, + # XXX: once m.direct is standardised everywhere, we should use it to detect + # a DM from the user's perspective rather than this heuristic. { 'rule_id': 'global/underride/.m.rule.room_one_to_one', 'conditions': [ @@ -289,6 +291,34 @@ BASE_APPEND_UNDERRIDE_RULES = [ } ] }, + # XXX: this is going to fire for events which aren't m.room.messages + # but are encrypted (e.g. m.call.*)... + { + 'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one', + 'conditions': [ + { + 'kind': 'room_member_count', + 'is': '2', + '_id': 'member_count', + }, + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.encrypted', + '_id': '_encrypted', + } + ], + 'actions': [ + 'notify', + { + 'set_tweak': 'sound', + 'value': 'default' + }, { + 'set_tweak': 'highlight', + 'value': False + } + ] + }, { 'rule_id': 'global/underride/.m.rule.message', 'conditions': [ @@ -305,6 +335,25 @@ BASE_APPEND_UNDERRIDE_RULES = [ 'value': False } ] + }, + # XXX: this is going to fire for events which aren't m.room.messages + # but are encrypted (e.g. m.call.*)... + { + 'rule_id': 'global/underride/.m.rule.encrypted', + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'type', + 'pattern': 'm.room.encrypted', + '_id': '_encrypted', + } + ], + 'actions': [ + 'notify', { + 'set_tweak': 'highlight', + 'value': False + } + ] } ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f1bbe57dcb..be55598c43 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -27,15 +27,6 @@ logger = logging.getLogger(__name__) @defer.inlineCallbacks -def _get_rules(room_id, user_ids, store): - rules_by_user = yield store.bulk_get_push_rules(user_ids) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - defer.returnValue(rules_by_user) - - -@defer.inlineCallbacks def evaluator_for_event(event, hs, store, context): rules_by_user = yield store.bulk_get_push_rules_for_room( event, context @@ -48,6 +39,7 @@ def evaluator_for_event(event, hs, store, context): if invited_user and hs.is_mine_id(invited_user): has_pusher = yield store.user_has_pusher(invited_user) if has_pusher: + rules_by_user = dict(rules_by_user) rules_by_user[invited_user] = yield store.get_push_rules_for_user( invited_user ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 86e3d89154..b9e41770ee 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -36,6 +36,7 @@ REQUIREMENTS = { "blist": ["blist"], "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"], "pymacaroons-pynacl": ["pymacaroons"], + "msgpack-python>=0.3.0": ["msgpack"], } CONDITIONAL_REQUIREMENTS = { "web_client": { diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 1ed9034bcb..9aab3ce23c 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -42,6 +42,7 @@ STREAM_NAMES = ( ("pushers",), ("caches",), ("to_device",), + ("public_rooms",), ) @@ -131,6 +132,7 @@ class ReplicationResource(Resource): push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() caches_token = self.store.get_cache_stream_token() + public_rooms_token = self.store.get_current_public_room_stream_id() defer.returnValue(_ReplicationToken( room_stream_token, @@ -144,6 +146,7 @@ class ReplicationResource(Resource): 0, # State stream is no longer a thing caches_token, int(stream_token.to_device_key), + int(public_rooms_token), )) @request_handler() @@ -181,7 +184,7 @@ class ReplicationResource(Resource): def replicate(self, request_streams, limit): writer = _Writer() current_token = yield self.current_replication_token() - logger.info("Replicating up to %r", current_token) + logger.debug("Replicating up to %r", current_token) yield self.account_data(writer, current_token, limit, request_streams) yield self.events(writer, current_token, limit, request_streams) @@ -193,9 +196,10 @@ class ReplicationResource(Resource): yield self.pushers(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams) + yield self.public_rooms(writer, current_token, limit, request_streams) self.streams(writer, current_token, request_streams) - logger.info("Replicated %d rows", writer.total) + logger.debug("Replicated %d rows", writer.total) defer.returnValue(writer.finish()) def streams(self, writer, current_token, request_streams): @@ -274,11 +278,18 @@ class ReplicationResource(Resource): @defer.inlineCallbacks def typing(self, writer, current_token, request_streams): - current_position = current_token.presence + current_position = current_token.typing request_typing = request_streams.get("typing") if request_typing is not None: + # If they have a higher token than current max, we can assume that + # they had been talking to a previous instance of the master. Since + # we reset the token on restart, the best (but hacky) thing we can + # do is to simply resend down all the typing notifications. + if request_typing > current_position: + request_typing = 0 + typing_rows = yield self.typing_handler.get_all_typing_updates( request_typing, current_position ) @@ -393,6 +404,20 @@ class ReplicationResource(Resource): "position", "user_id", "device_id", "message_json" )) + @defer.inlineCallbacks + def public_rooms(self, writer, current_token, limit, request_streams): + current_position = current_token.public_rooms + + public_rooms = request_streams.get("public_rooms") + + if public_rooms is not None: + public_rooms_rows = yield self.store.get_all_new_public_rooms( + public_rooms, current_position, limit + ) + writer.write_header_and_rows("public_rooms", public_rooms_rows, ( + "position", "room_id", "visibility" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -421,7 +446,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state", "caches", "to_device", + "push_rules", "pushers", "state", "caches", "to_device", "public_rooms", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 64d8eb2af1..3bfd5e8213 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,13 +16,18 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore +from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_inbox", "stream_id", + db_conn, "device_max_stream_id", "stream_id", + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + self._device_inbox_id_gen.get_current_token() ) get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ @@ -38,5 +43,11 @@ class SlavedDeviceInboxStore(BaseSlavedStore): stream = result.get("to_device") if stream: self._device_inbox_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + stream_id = row[0] + user_id = row[1] + self._device_inbox_stream_cache.entity_has_changed( + user_id, stream_id + ) return super(SlavedDeviceInboxStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index cbebd5b2f7..0c26e96e98 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -61,6 +61,9 @@ class SlavedEventStore(BaseSlavedStore): "MembershipStreamChangeCache", events_max, ) + self.stream_ordering_month_ago = 0 + self._stream_order_on_start = self.get_room_max_stream_ordering() + # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] @@ -86,6 +89,9 @@ class SlavedEventStore(BaseSlavedStore): _get_state_groups_from_groups = ( StateStore.__dict__["_get_state_groups_from_groups"] ) + _get_state_groups_from_groups_txn = ( + DataStore._get_state_groups_from_groups_txn.__func__ + ) _get_state_group_from_group = ( StateStore.__dict__["_get_state_group_from_group"] ) @@ -165,6 +171,15 @@ class SlavedEventStore(BaseSlavedStore): get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ + get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__ + + get_forward_extremeties_for_room = ( + DataStore.get_forward_extremeties_for_room.__func__ + ) + _get_forward_extremeties_for_room = ( + EventFederationStore.__dict__["_get_forward_extremeties_for_room"] + ) + def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index d5bb0f98ea..23c613863f 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -15,7 +15,39 @@ from ._base import BaseSlavedStore from synapse.storage import DataStore +from ._slaved_id_tracker import SlavedIdTracker class RoomStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(RoomStore, self).__init__(db_conn, hs) + self._public_room_id_gen = SlavedIdTracker( + db_conn, "public_room_list_stream", "stream_id" + ) + get_public_room_ids = DataStore.get_public_room_ids.__func__ + get_current_public_room_stream_id = ( + DataStore.get_current_public_room_stream_id.__func__ + ) + get_public_room_ids_at_stream_id = ( + DataStore.get_public_room_ids_at_stream_id.__func__ + ) + get_public_room_ids_at_stream_id_txn = ( + DataStore.get_public_room_ids_at_stream_id_txn.__func__ + ) + get_published_at_stream_id_txn = ( + DataStore.get_published_at_stream_id_txn.__func__ + ) + get_public_room_changes = DataStore.get_public_room_changes.__func__ + + def stream_positions(self): + result = super(RoomStore, self).stream_positions() + result["public_rooms"] = self._public_room_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("public_rooms") + if stream: + self._public_room_id_gen.advance(int(stream["position"])) + + return super(RoomStore, self).process_replication(result) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 6c0eec8fb3..345018a8fc 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -318,7 +318,7 @@ class CasRedirectServlet(ClientV1RestServlet): service_param = urllib.urlencode({ "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) }) - request.redirect("%s?%s" % (self.cas_server_url, service_param)) + request.redirect("%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -385,7 +385,7 @@ class CasTicketServlet(ClientV1RestServlet): def parse_cas_response(self, cas_response_body): user = None - attributes = None + attributes = {} try: root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): @@ -395,7 +395,6 @@ class CasTicketServlet(ClientV1RestServlet): if child.tag.endswith("user"): user = child.text if child.tag.endswith("attributes"): - attributes = {} for attribute in child: # ElementTree library expands the namespace in # attribute tags to the full URL of the namespace. @@ -407,8 +406,6 @@ class CasTicketServlet(ClientV1RestServlet): attributes[tag] = attribute.text if user is None: raise Exception("CAS response does not contain user") - if attributes is None: - raise Exception("CAS response does not contain attributes") except Exception: logger.error("Error parsing CAS response", exc_info=1) raise LoginError(401, "Invalid CAS response", diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 9bff02ee4e..1358d0acab 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from synapse.api.errors import AuthError, Codes +from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns @@ -37,13 +37,7 @@ class LogoutRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - try: - access_token = request.args["access_token"][0] - except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", - errcode=Codes.MISSING_TOKEN - ) + access_token = get_access_token_from_request(request) yield self.store.delete_access_token(access_token) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 71d58c8e8d..3046da7aec 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType +from synapse.api.auth import get_access_token_from_request from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.http.servlet import parse_json_object_from_request @@ -296,12 +297,11 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_app_service(self, request, register_json, session): - if "access_token" not in request.args: - raise SynapseError(400, "Expected application service token.") + as_token = get_access_token_from_request(request) + if "user" not in register_json: raise SynapseError(400, "Expected 'user' key.") - as_token = request.args["access_token"][0] user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler @@ -390,11 +390,9 @@ class CreateUserRestServlet(ClientV1RestServlet): def on_POST(self, request): user_json = parse_json_object_from_request(request) - if "access_token" not in request.args: - raise SynapseError(400, "Expected application service token.") - + access_token = get_access_token_from_request(request) app_service = yield self.store.get_app_service_by_token( - request.args["access_token"][0] + access_token ) if not app_service: raise SynapseError(403, "Invalid application service token.") diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 0d81757010..45287bf05b 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -22,8 +22,10 @@ from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias -from synapse.events.utils import serialize_event -from synapse.http.servlet import parse_json_object_from_request +from synapse.events.utils import serialize_event, format_event_for_client_v2 +from synapse.http.servlet import ( + parse_json_object_from_request, parse_string, parse_integer +) import logging import urllib @@ -120,6 +122,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): requester = yield self.auth.get_user_by_req(request, allow_guest=True) + format = parse_string(request, "format", default="content", + allowed_values=["content", "event"]) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -134,7 +138,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): raise SynapseError( 404, "Event not found.", errcode=Codes.NOT_FOUND ) - defer.returnValue((200, data.get_dict()["content"])) + + if format == "event": + event = format_event_for_client_v2(data.get_dict()) + defer.returnValue((200, event)) + elif format == "content": + defer.returnValue((200, data.get_dict()["content"])) @defer.inlineCallbacks def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): @@ -295,15 +304,64 @@ class PublicRoomListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): + server = parse_string(request, "server", default=None) + try: - yield self.auth.get_user_by_req(request) - except AuthError: - # This endpoint isn't authed, but its useful to know who's hitting - # it if they *do* supply an access token - pass + yield self.auth.get_user_by_req(request, allow_guest=True) + except AuthError as e: + # We allow people to not be authed if they're just looking at our + # room list, but require auth when we proxy the request. + # In both cases we call the auth function, as that has the side + # effect of logging who issued this request if an access token was + # provided. + if server: + raise e + else: + pass + + limit = parse_integer(request, "limit", 0) + since_token = parse_string(request, "since", None) handler = self.hs.get_room_list_handler() - data = yield handler.get_aggregated_public_room_list() + if server: + data = yield handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + ) + else: + data = yield handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + ) + + defer.returnValue((200, data)) + + @defer.inlineCallbacks + def on_POST(self, request): + yield self.auth.get_user_by_req(request, allow_guest=True) + + server = parse_string(request, "server", default=None) + content = parse_json_object_from_request(request) + + limit = int(content.get("limit", 100)) + since_token = content.get("since", None) + search_filter = content.get("filter", None) + + handler = self.hs.get_room_list_handler() + if server: + data = yield handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + ) + else: + data = yield handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + search_filter=search_filter, + ) defer.returnValue((200, data)) diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py index bdccf464a5..2f2c9d0881 100644 --- a/synapse/rest/client/v1/transactions.py +++ b/synapse/rest/client/v1/transactions.py @@ -17,6 +17,8 @@ to ensure idempotency when performing PUTs using the REST API.""" import logging +from synapse.api.auth import get_access_token_from_request + logger = logging.getLogger(__name__) @@ -90,6 +92,6 @@ class HttpTransactionStore(object): return response def _get_key(self, request): - token = request.args["access_token"][0] + token = get_access_token_from_request(request) path_without_txn_id = request.path.rsplit("/", 1)[0] return path_without_txn_id + "/" + token diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index c5ff16adf3..f185f9a774 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -15,15 +15,12 @@ import logging -import simplejson as json -from canonicaljson import encode_canonical_json from twisted.internet import defer -import synapse.api.errors -import synapse.server -import synapse.types -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_integer +) from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -63,17 +60,13 @@ class KeyUploadServlet(RestServlet): hs (synapse.server.HomeServer): server """ super(KeyUploadServlet, self).__init__() - self.store = hs.get_datastore() - self.clock = hs.get_clock() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_POST(self, request, device_id): requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) if device_id is not None: @@ -88,52 +81,15 @@ class KeyUploadServlet(RestServlet): device_id = requester.device_id if device_id is None: - raise synapse.api.errors.SynapseError( + raise SynapseError( 400, "To upload keys, you must pass device_id when authenticating" ) - time_now = self.clock.time_msec() - - # TODO: Validate the JSON to make sure it has the right keys. - device_keys = body.get("device_keys", None) - if device_keys: - logger.info( - "Updating device_keys for device %r for user %s at %d", - device_id, user_id, time_now - ) - # TODO: Sign the JSON with the server key - yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, - encode_canonical_json(device_keys) - ) - - one_time_keys = body.get("one_time_keys", None) - if one_time_keys: - logger.info( - "Adding %d one_time_keys for device %r for user %r at %d", - len(one_time_keys), device_id, user_id, time_now - ) - key_list = [] - for key_id, key_json in one_time_keys.items(): - algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, encode_canonical_json(key_json) - )) - - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, key_list - ) - - # the device should have been registered already, but it may have been - # deleted due to a race with a DELETE request. Or we may be using an - # old access_token without an associated device_id. Either way, we - # need to double-check the device is registered to avoid ending up with - # keys without a corresponding device. - self.device_handler.check_device_registered(user_id, device_id) - - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue((200, {"one_time_key_counts": result})) + result = yield self.e2e_keys_handler.upload_keys_for_user( + user_id, device_id, body + ) + defer.returnValue((200, result)) class KeyQueryServlet(RestServlet): @@ -195,20 +151,23 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body) - defer.returnValue(result) + result = yield self.e2e_keys_handler.query_devices(body, timeout) + defer.returnValue((200, result)) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): requester = yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] result = yield self.e2e_keys_handler.query_devices( - {"device_keys": {user_id: device_ids}} + {"device_keys": {user_id: device_ids}}, + timeout, ) - defer.returnValue(result) + defer.returnValue((200, result)) class OneTimeKeyServlet(RestServlet): @@ -240,59 +199,29 @@ class OneTimeKeyServlet(RestServlet): def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() - self.store = hs.get_datastore() self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) - result = yield self.handle_request( - {"one_time_keys": {user_id: {device_id: algorithm}}} + timeout = parse_integer(request, "timeout", 10 * 1000) + result = yield self.e2e_keys_handler.claim_one_time_keys( + {"one_time_keys": {user_id: {device_id: algorithm}}}, + timeout, ) - defer.returnValue(result) + defer.returnValue((200, result)) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) - defer.returnValue(result) - - @defer.inlineCallbacks - def handle_request(self, body): - local_query = [] - remote_queries = {} - for user_id, device_keys in body.get("one_time_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): - for device_id, algorithm in device_keys.items(): - local_query.append((user_id, device_id, algorithm)) - else: - remote_queries.setdefault(user.domain, {})[user_id] = ( - device_keys - ) - results = yield self.store.claim_e2e_one_time_keys(local_query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) - } - - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys} - ) - for user_id, keys in remote_result["one_time_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - - defer.returnValue((200, {"one_time_keys": json_result})) + result = yield self.e2e_keys_handler.claim_one_time_keys( + body, + timeout, + ) + defer.returnValue((200, result)) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index f1a48acf07..fd2a3d69d4 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -45,11 +45,12 @@ class NotificationsServlet(RestServlet): from_token = parse_string(request, "from", required=False) limit = parse_integer(request, "limit", default=50) + only = parse_string(request, "only", required=False) limit = min(limit, 500) push_actions = yield self.store.get_push_actions_for_user( - user_id, from_token, limit + user_id, from_token, limit, only_highlight=(only == "highlight") ) receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 2121bd75ea..68d18a9b82 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -131,7 +132,7 @@ class RegisterRestServlet(RestServlet): desired_username = body['username'] appservice = None - if 'access_token' in request.args: + if has_access_token(request): appservice = yield self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes and shared secret auth which @@ -143,10 +144,11 @@ class RegisterRestServlet(RestServlet): # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. desired_username = body.get("user", desired_username) + access_token = get_access_token_from_request(request) if isinstance(desired_username, basestring): result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0], body + desired_username, access_token, body ) defer.returnValue((200, result)) # we throw for non 200 responses return diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 9c10a99acf..5975164b37 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -16,10 +16,11 @@ import logging from twisted.internet import defer -from synapse.http.servlet import parse_json_object_from_request from synapse.http import servlet +from synapse.http.servlet import parse_json_object_from_request from synapse.rest.client.v1.transactions import HttpTransactionStore + from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -39,10 +40,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): super(SendToDeviceRestServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id self.txns = HttpTransactionStore() + self.device_message_handler = hs.get_device_message_handler() @defer.inlineCallbacks def on_PUT(self, request, message_type, txn_id): @@ -57,28 +56,10 @@ class SendToDeviceRestServlet(servlet.RestServlet): content = parse_json_object_from_request(request) - # TODO: Prod the notifier to wake up sync streams. - # TODO: Implement replication for the messages. - # TODO: Send the messages to remote servers if needed. - - local_messages = {} - for user_id, by_device in content["messages"].items(): - if self.is_mine_id(user_id): - messages_by_device = { - device_id: { - "content": message_content, - "type": message_type, - "sender": requester.user.to_string(), - } - for device_id, message_content in by_device.items() - } - if messages_by_device: - local_messages[user_id] = messages_by_device - - stream_id = yield self.store.add_messages_to_device_inbox(local_messages) - - self.notifier.on_new_event( - "to_device_key", stream_id, users=local_messages.keys() + sender_user_id = requester.user.to_string() + + yield self.device_message_handler.send_device_message( + sender_user_id, message_type, content["messages"] ) response = (200, {}) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 4f6f1a7e17..31f94bc6e9 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet): defer.returnValue((200, protocols)) +class ThirdPartyProtocolServlet(RestServlet): + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyProtocolServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + + protocols = yield self.appservice_handler.get_3pe_protocols( + only_protocol=protocol, + ) + if protocol in protocols: + defer.returnValue((200, protocols[protocol])) + else: + defer.returnValue((404, {"error": "Unknown protocol"})) + + class ThirdPartyUserServlet(RestServlet): PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", releases=()) @@ -57,7 +80,7 @@ class ThirdPartyUserServlet(RestServlet): yield self.auth.get_user_by_req(request) fields = request.args - del fields["access_token"] + fields.pop("access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields @@ -81,7 +104,7 @@ class ThirdPartyLocationServlet(RestServlet): yield self.auth.get_user_by_req(request) fields = request.args - del fields["access_token"] + fields.pop("access_token", None) results = yield self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields @@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet): def register_servlets(hs, http_server): ThirdPartyProtocolsServlet(hs).register(http_server) + ThirdPartyProtocolServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 9f0625a822..a45ee9483e 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -45,7 +45,14 @@ class DownloadResource(Resource): @request_handler() @defer.inlineCallbacks def _async_render_GET(self, request): - request.setHeader("Content-Security-Policy", "sandbox") + request.setHeader( + "Content-Security-Policy", + "default-src 'none';" + " script-src 'none';" + " plugin-types application/pdf;" + " style-src 'unsafe-inline';" + " object-src 'self';" + ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: yield self._respond_local_file(request, media_id, name) diff --git a/synapse/server.py b/synapse/server.py index af3246504b..69860f3d82 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -35,10 +35,11 @@ from synapse.federation import initialize_http_replication from synapse.handlers import Handlers from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler +from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.device import DeviceHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.presence import PresenceHandler -from synapse.handlers.room import RoomListHandler +from synapse.handlers.room_list import RoomListHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.events import EventHandler, EventStreamHandler @@ -100,6 +101,7 @@ class HomeServer(object): 'application_service_api', 'application_service_scheduler', 'application_service_handler', + 'device_message_handler', 'notifier', 'distributor', 'client_resource', @@ -205,6 +207,9 @@ class HomeServer(object): def build_device_handler(self): return DeviceHandler(self) + def build_device_message_handler(self): + return DeviceMessageHandler(self) + def build_e2e_keys_handler(self): return E2eKeysHandler(self) diff --git a/synapse/state.py b/synapse/state.py index cd792afed1..b4eca0e5d5 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -26,6 +26,7 @@ from synapse.events.snapshot import EventContext from synapse.util.async import Linearizer from collections import namedtuple +from frozendict import frozendict import logging import hashlib @@ -55,12 +56,15 @@ def _gen_state_id(): class _StateCacheEntry(object): - __slots__ = ["state", "state_group", "state_id"] + __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] - def __init__(self, state, state_group): - self.state = state + def __init__(self, state, state_group, prev_group=None, delta_ids=None): + self.state = frozendict(state) self.state_group = state_group + self.prev_group = prev_group + self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None + # The `state_id` is a unique ID we generate that can be used as ID for # this collection of state. Usually this would be the same as the # state group, but on worker instances we can't generate a new state @@ -153,8 +157,9 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks - def get_current_user_in_room(self, room_id): - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + def get_current_user_in_room(self, room_id, latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) entry = yield self.resolve_state_groups(room_id, latest_event_ids) joined_users = yield self.store.get_joined_users_from_state( room_id, entry.state_id, entry.state @@ -234,21 +239,29 @@ class StateHandler(object): context.prev_state_ids = curr_state if event.is_state(): context.state_group = self.store.get_next_state_group() - else: - if entry.state_group is None: - entry.state_group = self.store.get_next_state_group() - entry.state_id = entry.state_group - context.state_group = entry.state_group - if event.is_state(): key = (event.type, event.state_key) if key in context.prev_state_ids: replaces = context.prev_state_ids[key] event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids[key] = event.event_id + + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids + if context.delta_ids is not None: + context.delta_ids = dict(context.delta_ids) + context.delta_ids[key] = event.event_id else: + if entry.state_group is None: + entry.state_group = self.store.get_next_state_group() + entry.state_id = entry.state_group + + context.state_group = entry.state_group context.current_state_ids = context.prev_state_ids + context.prev_group = entry.prev_group + context.delta_ids = entry.delta_ids context.prev_state_events = [] defer.returnValue(context) @@ -283,6 +296,8 @@ class StateHandler(object): defer.returnValue(_StateCacheEntry( state=state_list, state_group=name, + prev_group=name, + delta_ids={}, )) with (yield self.resolve_linearizer.queue(group_names)): @@ -340,9 +355,24 @@ class StateHandler(object): if hasattr(self.store, "get_next_state_group"): state_group = self.store.get_next_state_group() + prev_group = None + delta_ids = None + for old_group, old_ids in state_groups_ids.items(): + if not set(new_state.iterkeys()) - set(old_ids.iterkeys()): + n_delta_ids = { + k: v + for k, v in new_state.items() + if old_ids.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids + cache = _StateCacheEntry( state=new_state, state_group=state_group, + prev_group=prev_group, + delta_ids=delta_ids, ) if self._state_cache is not None: diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 6c32773f25..9996f195a0 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -111,7 +111,10 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", "stream_id" ) self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_inbox", "stream_id" + db_conn, "device_max_stream_id", "stream_id" + ) + self._public_room_id_gen = StreamIdGenerator( + db_conn, "public_room_list_stream", "stream_id" ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") @@ -182,6 +185,30 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=push_rules_prefill, ) + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + db_conn, "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + # The federation outbox and the local device inbox uses the same + # stream_id generator. + device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + db_conn, "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", @@ -195,6 +222,8 @@ class DataStore(RoomMemberStore, RoomStore, self._find_stream_orderings_for_times, 60 * 60 * 1000 ) + self._stream_order_on_start = self.get_room_max_stream_ordering() + super(DataStore, self).__init__(hs) def take_presence_startup_info(self): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 30d0e4c5dc..94b2bcc54a 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -133,10 +133,12 @@ class BackgroundUpdateStore(SQLBaseStore): updates = yield self._simple_select_list( "background_updates", keyvalues=None, - retcols=("update_name",), + retcols=("update_name", "depends_on"), ) + in_flight = set(update["update_name"] for update in updates) for update in updates: - self._background_update_queue.append(update['update_name']) + if update["depends_on"] not in in_flight: + self._background_update_queue.append(update['update_name']) if not self._background_update_queue: # no work left to do @@ -217,7 +219,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_handlers[update_name] = update_handler def register_background_index_update(self, update_name, index_name, - table, columns): + table, columns, where_clause=None): """Helper for store classes to do a background index addition To use: @@ -241,14 +243,20 @@ class BackgroundUpdateStore(SQLBaseStore): conc = True else: conc = False - - sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \ - % { - "conc": "CONCURRENTLY" if conc else "", - "name": index_name, - "table": table, - "columns": ", ".join(columns), - } + # We don't use partial indices on SQLite as it wasn't introduced + # until 3.8, and wheezy has 3.7 + where_clause = None + + sql = ( + "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" + " %(where_clause)s" + ) % { + "conc": "CONCURRENTLY" if conc else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + "where_clause": "WHERE " + where_clause if where_clause else "" + } def create_index_concurrently(conn): conn.rollback() diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 68116b0394..f640e73714 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -27,63 +27,172 @@ logger = logging.getLogger(__name__) class DeviceInboxStore(SQLBaseStore): @defer.inlineCallbacks - def add_messages_to_device_inbox(self, messages_by_user_then_device): - """ + def add_messages_to_device_inbox(self, local_messages_by_user_then_device, + remote_messages_by_destination): + """Used to send messages from this server. + Args: - messages_by_user_and_device(dict): + sender_user_id(str): The ID of the user sending these messages. + local_messages_by_user_and_device(dict): Dictionary of user_id to device_id to message. + remote_messages_by_destination(dict): + Dictionary of destination server_name to the EDU JSON to send. Returns: A deferred stream_id that resolves when the messages have been inserted. """ - def select_devices_txn(txn, user_id, devices): - if not devices: - return [] - sql = ( - "SELECT user_id, device_id FROM devices" - " WHERE user_id = ? AND device_id IN (" - + ",".join("?" * len(devices)) - + ")" + def add_messages_txn(txn, now_ms, stream_id): + # Add the local messages directly to the local inbox. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device ) - # TODO: Maybe this needs to be done in batches if there are - # too many local devices for a given user. - args = [user_id] + devices - txn.execute(sql, args) - return [tuple(row) for row in txn.fetchall()] - - def add_messages_to_device_inbox_txn(txn, stream_id): - local_users_and_devices = set() - for user_id, messages_by_device in messages_by_user_then_device.items(): - local_users_and_devices.update( - select_devices_txn(txn, user_id, messages_by_device.keys()) - ) + # Add the remote messages to the federation outbox. + # We'll send them to a remote server when we next send a + # federation transaction to that destination. sql = ( - "INSERT INTO device_inbox" - " (user_id, device_id, stream_id, message_json)" + "INSERT INTO device_federation_outbox" + " (destination, stream_id, queued_ts, messages_json)" " VALUES (?,?,?,?)" ) rows = [] - for user_id, messages_by_device in messages_by_user_then_device.items(): - for device_id, message in messages_by_device.items(): - message_json = ujson.dumps(message) - # Only insert into the local inbox if the device exists on - # this server - if (user_id, device_id) in local_users_and_devices: - rows.append((user_id, device_id, stream_id, message_json)) - + for destination, edu in remote_messages_by_destination.items(): + edu_json = ujson.dumps(edu) + rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() yield self.runInteraction( "add_messages_to_device_inbox", - add_messages_to_device_inbox_txn, - stream_id + add_messages_txn, + now_ms, + stream_id, ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed( + user_id, stream_id + ) + for destination in remote_messages_by_destination.keys(): + self._device_federation_outbox_stream_cache.entity_has_changed( + destination, stream_id + ) defer.returnValue(self._device_inbox_id_gen.get_current_token()) + @defer.inlineCallbacks + def add_messages_from_remote_to_device_inbox( + self, origin, message_id, local_messages_by_user_then_device + ): + def add_messages_txn(txn, now_ms, stream_id): + # Check if we've already inserted a matching message_id for that + # origin. This can happen if the origin doesn't receive our + # acknowledgement from the first time we received the message. + already_inserted = self._simple_select_one_txn( + txn, table="device_federation_inbox", + keyvalues={"origin": origin, "message_id": message_id}, + retcols=("message_id",), + allow_none=True, + ) + if already_inserted is not None: + return + + # Add an entry for this message_id so that we know we've processed + # it. + self._simple_insert_txn( + txn, table="device_federation_inbox", + values={ + "origin": origin, + "message_id": message_id, + "received_ts": now_ms, + }, + ) + + # Add the messages to the approriate local device inboxes so that + # they'll be sent to the devices when they next sync. + self._add_messages_to_local_device_inbox_txn( + txn, stream_id, local_messages_by_user_then_device + ) + + with self._device_inbox_id_gen.get_next() as stream_id: + now_ms = self.clock.time_msec() + yield self.runInteraction( + "add_messages_from_remote_to_device_inbox", + add_messages_txn, + now_ms, + stream_id, + ) + for user_id in local_messages_by_user_then_device.keys(): + self._device_inbox_stream_cache.entity_has_changed( + user_id, stream_id + ) + + defer.returnValue(stream_id) + + def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, + messages_by_user_then_device): + sql = ( + "UPDATE device_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(sql, (stream_id, stream_id)) + + local_by_user_then_device = {} + for user_id, messages_by_device in messages_by_user_then_device.items(): + messages_json_for_user = {} + devices = messages_by_device.keys() + if len(devices) == 1 and devices[0] == "*": + # Handle wildcard device_ids. + sql = ( + "SELECT device_id FROM devices" + " WHERE user_id = ?" + ) + txn.execute(sql, (user_id,)) + message_json = ujson.dumps(messages_by_device["*"]) + for row in txn.fetchall(): + # Add the message for all devices for this user on this + # server. + device = row[0] + messages_json_for_user[device] = message_json + else: + if not devices: + continue + sql = ( + "SELECT device_id FROM devices" + " WHERE user_id = ? AND device_id IN (" + + ",".join("?" * len(devices)) + + ")" + ) + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. + txn.execute(sql, [user_id] + devices) + for row in txn.fetchall(): + # Only insert into the local inbox if the device exists on + # this server + device = row[0] + message_json = ujson.dumps(messages_by_device[device]) + messages_json_for_user[device] = message_json + + if messages_json_for_user: + local_by_user_then_device[user_id] = messages_json_for_user + + if not local_by_user_then_device: + return + + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in local_by_user_then_device.items(): + for device_id, message_json in messages_by_device.items(): + rows.append((user_id, device_id, stream_id, message_json)) + + txn.executemany(sql, rows) + def get_new_messages_for_device( self, user_id, device_id, last_stream_id, current_stream_id, limit=100 ): @@ -97,6 +206,12 @@ class DeviceInboxStore(SQLBaseStore): Deferred ([dict], int): List of messages for the device and where in the stream the messages got to. """ + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_stream_id + ) + if not has_changed: + return defer.succeed(([], current_stream_id)) + def get_new_messages_for_device_txn(txn): sql = ( "SELECT stream_id, message_json FROM device_inbox" @@ -182,3 +297,71 @@ class DeviceInboxStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() + + def get_new_device_msgs_for_remote( + self, destination, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int): The last position of the device message stream + that the server sent up to. + current_stream_id(int): The current position of the device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + + has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( + destination, last_stream_id + ) + if not has_changed or last_stream_id == current_stream_id: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + destination, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn.fetchall(): + stream_pos = row[0] + messages.append(ujson.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_device_msgs_for_remote", + get_new_messages_for_remote_destination_txn, + ) + + def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.runInteraction( + "delete_device_msgs_for_remote", + delete_messages_for_remote_destination_txn + ) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index afd6530cab..17920d4480 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -54,8 +54,12 @@ class DeviceStore(SQLBaseStore): or_ignore=ignore_if_known, ) except Exception as e: - logger.error("store_device with device_id=%s failed: %s", - device_id, e) + logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" + " display_name=%s(%r) failed: %s", + type(device_id).__name__, device_id, + type(user_id).__name__, user_id, + type(initial_device_display_name).__name__, + initial_device_display_name, e) raise StoreError(500, "Problem storing device.") def get_device(self, user_id, device_id): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 0827946207..3d62451de9 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -16,6 +16,7 @@ from twisted.internet import defer from ._base import SQLBaseStore +from synapse.api.errors import StoreError from synapse.util.caches.descriptors import cached from unpaddedbase64 import encode_base64 @@ -36,6 +37,13 @@ class EventFederationStore(SQLBaseStore): and backfilling from another server respectively. """ + def __init__(self, hs): + super(EventFederationStore, self).__init__(hs) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + def get_auth_chain(self, event_ids): return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) @@ -270,6 +278,37 @@ class EventFederationStore(SQLBaseStore): ] ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + max_stream_ord = max( + ev.internal_metadata.stream_ordering for ev in events + ) + new_extrem = {} + for room_id in events_by_room: + event_ids = self._simple_select_onecol_txn( + txn, + table="event_forward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + ) + new_extrem[room_id] = event_ids + + self._simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_ord, + } + for room_id, extrem_evs in new_extrem.items() + for event_id in extrem_evs + ] + ) + query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" @@ -305,6 +344,76 @@ class EventFederationStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) + def get_forward_extremeties_for_room(self, room_id, stream_ordering): + # We want to make the cache more effective, so we clamp to the last + # change before the given ordering. + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + + # We don't always have a full stream_to_exterm_id table, e.g. after + # the upgrade that introduced it, so we make sure we never ask for a + # try and pin to a stream_ordering from before a restart + last_change = max(self._stream_order_on_start, last_change) + + if last_change > self.stream_ordering_month_ago: + stream_ordering = min(last_change, stream_ordering) + + return self._get_forward_extremeties_for_room(room_id, stream_ordering) + + @cached(max_entries=5000, num_args=2) + def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + """ + + if stream_ordering <= self.stream_ordering_month_ago: + raise StoreError(400, "stream_ordering too old") + + sql = (""" + SELECT event_id FROM stream_ordering_to_exterm + INNER JOIN ( + SELECT room_id, MAX(stream_ordering) AS stream_ordering + FROM stream_ordering_to_exterm + WHERE stream_ordering <= ? GROUP BY room_id + ) AS rms USING (room_id, stream_ordering) + WHERE room_id = ? + """) + + def get_forward_extremeties_for_room_txn(txn): + txn.execute(sql, (stream_ordering, room_id)) + rows = txn.fetchall() + return [event_id for event_id, in rows] + + return self.runInteraction( + "get_forward_extremeties_for_room", + get_forward_extremeties_for_room_txn + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = (""" + DELETE FROM stream_ordering_to_exterm + WHERE + ( + SELECT max(stream_ordering) AS stream_ordering + FROM stream_ordering_to_exterm + WHERE room_id = stream_ordering_to_exterm.room_id + ) > ? + AND stream_ordering < ? + """) + txn.execute( + sql, + (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) + ) + return self.runInteraction( + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn + ) + def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index eb15fb751b..9cd923eb93 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -26,10 +26,19 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + def __init__(self, hs): self.stream_ordering_month_ago = None super(EventPushActionsStore, self).__init__(hs) + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ Args: @@ -338,20 +347,29 @@ class EventPushActionsStore(SQLBaseStore): defer.returnValue(notifs[:limit]) @defer.inlineCallbacks - def get_push_actions_for_user(self, user_id, before=None, limit=50): + def get_push_actions_for_user(self, user_id, before=None, limit=50, + only_highlight=False): def f(txn): before_clause = "" if before: - before_clause = "AND stream_ordering < ?" + before_clause = "AND epa.stream_ordering < ?" args = [user_id, before, limit] else: args = [user_id, limit] + + if only_highlight: + if len(before_clause) > 0: + before_clause += " " + before_clause += "AND epa.highlight = 1" + + # NB. This assumes event_ids are globally unique since + # it makes the query easier to index sql = ( "SELECT epa.event_id, epa.room_id," " epa.stream_ordering, epa.topological_ordering," " epa.actions, epa.profile_tag, e.received_ts" " FROM event_push_actions epa, events e" - " WHERE epa.room_id = e.room_id AND epa.event_id = e.event_id" + " WHERE epa.event_id = e.event_id" " AND epa.user_id = ? %s" " ORDER BY epa.stream_ordering DESC" " LIMIT ?" diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1a7d4c5199..6dc46fa50f 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -189,6 +189,14 @@ class EventsStore(SQLBaseStore): self._background_reindex_fields_sender, ) + self.register_background_index_update( + "event_contains_url_index", + index_name="event_contains_url_index", + table="events", + columns=["room_id", "topological_ordering", "stream_ordering"], + where_clause="contains_url = true AND outlier = false", + ) + self._event_persist_queue = _EventPeristenceQueue() def persist_events(self, events_and_contexts, backfilled=False): @@ -497,7 +505,11 @@ class EventsStore(SQLBaseStore): # insert into the state_group, state_groups_state and # event_to_state_groups tables. - self._store_mult_state_groups_txn(txn, ((event, context),)) + try: + self._store_mult_state_groups_txn(txn, ((event, context),)) + except Exception: + logger.exception("") + raise metadata_json = encode_json( event.internal_metadata.get_dict() @@ -1543,6 +1555,9 @@ class EventsStore(SQLBaseStore): ) event_rows = txn.fetchall() + for event_id, state_key in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + # We calculate the new entries for the backward extremeties by finding # all events that point to events that are to be purged txn.execute( @@ -1582,7 +1597,66 @@ class EventsStore(SQLBaseStore): " GROUP BY state_group HAVING MAX(topological_ordering) < ?", (room_id, topological_ordering, topological_ordering) ) + state_rows = txn.fetchall() + state_groups_to_delete = [sg for sg, in state_rows] + + # Now we get all the state groups that rely on these state groups + new_state_edges = [] + chunks = [ + state_groups_to_delete[i:i + 100] + for i in xrange(0, len(state_groups_to_delete), 100) + ] + for chunk in chunks: + rows = self._simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=chunk, + retcols=["state_group"], + keyvalues={}, + ) + new_state_edges.extend(row["state_group"] for row in rows) + + # Now we turn the state groups that reference to-be-deleted state groups + # to non delta versions. + for new_state_edge in new_state_edges: + curr_state = self._get_state_groups_from_groups_txn( + txn, [new_state_edge], types=None + ) + curr_state = curr_state[new_state_edge] + + self._simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": new_state_edge, + } + ) + + self._simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": new_state_edge, + } + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": new_state_edge, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in curr_state.items() + ], + ) + txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", state_rows diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b94ce7bea1..7efbe51cda 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 34 +SCHEMA_VERSION = 35 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -242,7 +242,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module = imp.load_source( module_name, absolute_path, python_file ) - logger.debug("Running script %s", relative_path) + logger.info("Running script %s", relative_path) module.run_create(cur, database_engine) if not is_empty: module.run_upgrade(cur, database_engine, config=config) @@ -253,7 +253,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, pass elif ext == ".sql": # A plain old .sql file, just read and execute it - logger.debug("Applying schema %s", relative_path) + logger.info("Applying schema %s", relative_path) executescript(cur, absolute_path) else: # Not a valid delta file. diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8251f58670..2ef13d7403 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -48,15 +48,31 @@ class RoomStore(SQLBaseStore): StoreError if the room could not be stored. """ try: - yield self._simple_insert( - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - }, - desc="store_room", - ) + def store_room_txn(txn, next_id): + self._simple_insert_txn( + txn, + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + }, + ) + if is_public: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + } + ) + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction( + "store_room_txn", + store_room_txn, next_id, + ) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -77,13 +93,45 @@ class RoomStore(SQLBaseStore): allow_none=True, ) + @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): - return self._simple_update_one( - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - desc="set_room_is_public", - ) + def set_room_is_public_txn(txn, next_id): + self._simple_update_one_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + ) + + entries = self._simple_select_list_txn( + txn, + table="public_room_list_stream", + keyvalues={"room_id": room_id}, + retcols=("stream_id", "visibility"), + ) + + entries.sort(key=lambda r: r["stream_id"]) + + add_to_stream = True + if entries: + add_to_stream = bool(entries[-1]["visibility"]) != is_public + + if add_to_stream: + self._simple_insert_txn( + txn, + table="public_room_list_stream", + values={ + "stream_id": next_id, + "room_id": room_id, + "visibility": is_public, + } + ) + + with self._public_room_id_gen.get_next() as next_id: + yield self.runInteraction( + "set_room_is_public", + set_room_is_public_txn, next_id, + ) def get_public_room_ids(self): return self._simple_select_onecol( @@ -207,3 +255,71 @@ class RoomStore(SQLBaseStore): }, desc="add_event_report" ) + + def get_current_public_room_stream_id(self): + return self._public_room_id_gen.get_current_token() + + def get_public_room_ids_at_stream_id(self, stream_id): + return self.runInteraction( + "get_public_room_ids_at_stream_id", + self.get_public_room_ids_at_stream_id_txn, stream_id + ) + + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id): + return { + rm + for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items() + if vis + } + + def get_published_at_stream_id_txn(self, txn, stream_id): + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream + INNER JOIN ( + SELECT room_id, max(stream_id) AS stream_id + FROM public_room_list_stream + WHERE stream_id <= ? + GROUP BY room_id + ) grouped USING (room_id, stream_id) + """) + + txn.execute(sql, (stream_id,)) + return dict(txn.fetchall()) + + def get_public_room_changes(self, prev_stream_id, new_stream_id): + def get_public_room_changes_txn(txn): + then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id) + + now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id) + + now_rooms_visible = set( + rm for rm, vis in now_rooms_dict.items() if vis + ) + now_rooms_not_visible = set( + rm for rm, vis in now_rooms_dict.items() if not vis + ) + + newly_visible = now_rooms_visible - then_rooms + newly_unpublished = now_rooms_not_visible & then_rooms + + return newly_visible, newly_unpublished + + return self.runInteraction( + "get_public_room_changes", get_public_room_changes_txn + ) + + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = (""" + SELECT stream_id, room_id, visibility FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """) + + txn.execute(sql, (prev_id, current_id, limit,)) + return txn.fetchall() + + return self.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/schema/delta/22/receipts_index.sql index 7bc061dff6..bfc0b3bcaa 100644 --- a/synapse/storage/schema/delta/22/receipts_index.sql +++ b/synapse/storage/schema/delta/22/receipts_index.sql @@ -13,6 +13,10 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id ); diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql index 200c35e6e2..36609475f1 100644 --- a/synapse/storage/schema/delta/28/events_room_stream.sql +++ b/synapse/storage/schema/delta/28/events_room_stream.sql @@ -13,4 +13,8 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX events_room_stream on events(room_id, stream_ordering); diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql index ba62a974a4..6c1fd68c5b 100644 --- a/synapse/storage/schema/delta/28/public_roms_index.sql +++ b/synapse/storage/schema/delta/28/public_roms_index.sql @@ -13,4 +13,8 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX public_room_index on rooms(is_public); diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/schema/delta/28/receipts_user_id_index.sql index 452a1b3c6c..cb84c69baa 100644 --- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql +++ b/synapse/storage/schema/delta/28/receipts_user_id_index.sql @@ -13,6 +13,10 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id ); diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql index 7e7b09820a..84b21cf813 100644 --- a/synapse/storage/schema/delta/29/push_actions.sql +++ b/synapse/storage/schema/delta/29/push_actions.sql @@ -26,6 +26,10 @@ UPDATE event_push_actions SET stream_ordering = ( UPDATE event_push_actions SET notif = 1, highlight = 0; +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX event_push_actions_rm_tokens on event_push_actions( user_id, room_id, topological_ordering, stream_ordering ); diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/schema/delta/31/pushers_index.sql index 9027bccc69..a82add88fd 100644 --- a/synapse/storage/schema/delta/31/pushers_index.sql +++ b/synapse/storage/schema/delta/31/pushers_index.sql @@ -13,6 +13,10 @@ * limitations under the License. */ +/** Using CREATE INDEX directly is deprecated in favour of using background + * update see synapse/storage/schema/delta/33/access_tokens_device_index.sql + * and synapse/storage/registration.py for an example using + * "access_tokens_device_index" **/ CREATE INDEX event_push_actions_stream_ordering on event_push_actions( stream_ordering, user_id ); diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/schema/delta/35/add_state_index.sql new file mode 100644 index 0000000000..0fce26345b --- /dev/null +++ b/synapse/storage/schema/delta/35/add_state_index.sql @@ -0,0 +1,20 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +ALTER TABLE background_updates ADD COLUMN depends_on TEXT; + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication'); diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/schema/delta/35/contains_url.sql new file mode 100644 index 0000000000..6cd123027b --- /dev/null +++ b/synapse/storage/schema/delta/35/contains_url.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('event_contains_url_index', '{}'); diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/schema/delta/35/device_outbox.sql new file mode 100644 index 0000000000..17e6c43105 --- /dev/null +++ b/synapse/storage/schema/delta/35/device_outbox.sql @@ -0,0 +1,39 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +DROP TABLE IF EXISTS device_federation_outbox; +CREATE TABLE device_federation_outbox ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + queued_ts BIGINT NOT NULL, + messages_json TEXT NOT NULL +); + + +DROP INDEX IF EXISTS device_federation_outbox_destination_id; +CREATE INDEX device_federation_outbox_destination_id + ON device_federation_outbox(destination, stream_id); + + +DROP TABLE IF EXISTS device_federation_inbox; +CREATE TABLE device_federation_inbox ( + origin TEXT NOT NULL, + message_id TEXT NOT NULL, + received_ts BIGINT NOT NULL +); + +DROP INDEX IF EXISTS device_federation_inbox_sender_id; +CREATE INDEX device_federation_inbox_sender_id + ON device_federation_inbox(origin, message_id); diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/schema/delta/35/device_stream_id.sql new file mode 100644 index 0000000000..7ab7d942e2 --- /dev/null +++ b/synapse/storage/schema/delta/35/device_stream_id.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_max_stream_id ( + stream_id BIGINT NOT NULL +); + +INSERT INTO device_max_stream_id (stream_id) + SELECT COALESCE(MAX(stream_id), 0) FROM device_inbox; diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/schema/delta/35/event_push_actions_index.sql new file mode 100644 index 0000000000..2e836d8e9c --- /dev/null +++ b/synapse/storage/schema/delta/35/event_push_actions_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + INSERT into background_updates (update_name, progress_json) + VALUES ('epa_highlight_index', '{}'); diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/schema/delta/35/public_room_list_change_stream.sql new file mode 100644 index 0000000000..dd2bf2e28a --- /dev/null +++ b/synapse/storage/schema/delta/35/public_room_list_change_stream.sql @@ -0,0 +1,33 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE public_room_list_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + visibility BOOLEAN NOT NULL +); + +INSERT INTO public_room_list_stream (stream_id, room_id, visibility) + SELECT 1, room_id, is_public FROM rooms + WHERE is_public = CAST(1 AS BOOLEAN); + +CREATE INDEX public_room_list_stream_idx on public_room_list_stream( + stream_id +); + +CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( + room_id, stream_id +); diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/schema/delta/35/state.sql new file mode 100644 index 0000000000..0f1fa68a89 --- /dev/null +++ b/synapse/storage/schema/delta/35/state.sql @@ -0,0 +1,22 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_group_edges( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/schema/delta/35/state_dedupe.sql new file mode 100644 index 0000000000..97e5067ef4 --- /dev/null +++ b/synapse/storage/schema/delta/35/state_dedupe.sql @@ -0,0 +1,17 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('state_group_state_deduplication', '{}'); diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/schema/delta/35/stream_order_to_extrem.sql new file mode 100644 index 0000000000..2b945d8a57 --- /dev/null +++ b/synapse/storage/schema/delta/35/stream_order_to_extrem.sql @@ -0,0 +1,37 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE stream_ordering_to_exterm ( + stream_ordering BIGINT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id) + SELECT stream_ordering, room_id, event_id FROM event_forward_extremities + INNER JOIN ( + SELECT room_id, max(stream_ordering) as stream_ordering FROM events + INNER JOIN event_forward_extremities USING (room_id, event_id) + GROUP BY room_id + ) AS rms USING (room_id); + +CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( + stream_ordering +); + +CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( + room_id, stream_ordering +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ec551b0b4f..7eb342674c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -16,6 +16,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches import intern_string +from synapse.storage.engines import PostgresEngine from twisted.internet import defer @@ -24,6 +25,9 @@ import logging logger = logging.getLogger(__name__) +MAX_STATE_DELTA_HOPS = 100 + + class StateStore(SQLBaseStore): """ Keeps track of the state at a given event. @@ -43,6 +47,20 @@ class StateStore(SQLBaseStore): * `state_groups_state`: Maps state group to state events. """ + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + + def __init__(self, hs): + super(StateStore, self).__init__(hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): if not event_ids: @@ -103,11 +121,8 @@ class StateStore(SQLBaseStore): state_groups[event.event_id] = context.state_group if self._have_persisted_state_group_txn(txn, context.state_group): - logger.info("Already persisted state_group: %r", context.state_group) continue - state_event_ids = dict(context.current_state_ids) - self._simple_insert_txn( txn, table="state_groups", @@ -118,20 +133,51 @@ class StateStore(SQLBaseStore): }, ) - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if context.prev_group: + potential_hops = self._count_state_group_hops_txn( + txn, context.prev_group + ) + if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in state_event_ids.items() - ], - ) + "prev_state_group": context.prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.delta_ids.items() + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in context.current_state_ids.items() + ], + ) self._simple_insert_many_txn( txn, @@ -145,6 +191,47 @@ class StateStore(SQLBaseStore): ], ) + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): if event_type and state_key is not None: @@ -206,48 +293,124 @@ class StateStore(SQLBaseStore): def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ - def f(txn, groups): - if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) - else: - where_clause = "" - - sql = ( - "SELECT state_group, event_id, type, state_key" - " FROM state_groups_state WHERE" - " state_group IN (%s) %s" % ( - ",".join("?" for _ in groups), - where_clause, - ) - ) - - args = list(groups) - if types is not None: - args.extend([i for typ in types for i in typ]) - - txn.execute(sql, args) - rows = self.cursor_to_dict(txn) - - results = {group: {} for group in groups} - for row in rows: - key = (row["type"], row["state_key"]) - results[row["state_group"]][key] = row["event_id"] - return results - results = {} chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - f, chunk + self._get_state_groups_from_groups_txn, chunk, types, ) results.update(res) defer.returnValue(results) + def _get_state_groups_from_groups_txn(self, txn, groups, types=None): + results = {group: {} for group in groups} + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + %s + """) + + # Turns out that postgres doesn't like doing a list of OR's and + # is about 1000x slower, so we just issue a query for each specific + # type seperately. + if types: + clause_to_args = [ + ( + "AND type = ? AND state_key = ?", + (etype, state_key) + ) + for etype, state_key in types + ] + else: + # If types is None we fetch all the state, and so just use an + # empty where clause with no extra args. + clause_to_args = [("", [])] + + for where_clause, where_args in clause_to_args: + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql % (where_clause,), args) + rows = self.cursor_to_dict(txn) + for row in rows: + key = (row["type"], row["state_key"]) + results[group][key] = row["event_id"] + else: + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + group_tree = [group] + next_group = group + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + group_tree.append(next_group) + + sql = (""" + SELECT type, state_key, event_id FROM state_groups_state + INNER JOIN ( + SELECT type, state_key, max(state_group) as state_group + FROM state_groups_state + WHERE state_group IN (%s) %s + GROUP BY type, state_key + ) USING (type, state_key, state_group); + """) % (",".join("?" for _ in group_tree), where_clause,) + + args = list(group_tree) + if types is not None: + args.extend([i for typ in types for i in typ]) + + txn.execute(sql, args) + rows = self.cursor_to_dict(txn) + for row in rows: + key = (row["type"], row["state_key"]) + results[group][key] = row["event_id"] + + return results + @defer.inlineCallbacks def get_state_for_events(self, event_ids, types): """Given a list of event_ids and type tuples, return a list of state @@ -504,32 +667,184 @@ class StateStore(SQLBaseStore): defer.returnValue(results) - def get_all_new_state_groups(self, last_id, current_id, limit): - def get_all_new_state_groups_txn(txn): - sql = ( - "SELECT id, room_id, event_id FROM state_groups" - " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?" + def get_next_state_group(self): + return self._state_groups_id_gen.get_next() + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self._execute( + "_background_deduplicate_state", None, + "SELECT coalesce(max(id), 0) FROM state_groups", ) - txn.execute(sql, (last_id, current_id, limit)) - groups = txn.fetchall() + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in xrange(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group,) + ) + row = txn.fetchone() + if row: + state_group, room_id = row - if not groups: - return ([], []) + if not row or not state_group: + return True, count - lower_bound = groups[0][0] - upper_bound = groups[-1][0] - sql = ( - "SELECT state_group, type, state_key, event_id" - " FROM state_groups_state" - " WHERE ? <= state_group AND state_group <= ?" + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,) + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id,) + ) + prev_group, = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn( + txn, prev_group + ) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group], types=None + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group], types=None + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value for key, value in curr_state.items() + if prev_state.get(key, None) != value + } + + self._simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + } + ) + + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + } + ) + + self._simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + } + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_state.items() + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) - txn.execute(sql, (lower_bound, upper_bound)) - state_group_state = txn.fetchall() - return (groups, state_group_state) - return self.runInteraction( - "get_all_new_state_groups", get_all_new_state_groups_txn + return False, batch_size + + finished, result = yield self.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) - def get_next_state_group(self): - return self._state_groups_id_gen.get_next() + if finished: + yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME) + + defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR) + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute( + "DROP INDEX IF EXISTS state_groups_state_id" + ) + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute( + "DROP INDEX IF EXISTS state_groups_state_id" + ) + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + defer.returnValue(1) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 0577a0525b..07ea969d4d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -531,6 +531,9 @@ class StreamStore(SQLBaseStore): ) defer.returnValue("t%d-%d" % (topo, token)) + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + def get_stream_token_for_event(self, event_id): """The stream token for an event Args: diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 3c051dabc4..b72bb0ff02 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -121,3 +121,9 @@ class StreamChangeCache(object): k, r = self._cache.popitem() self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) self._entity_to_key.pop(r, None) + + def get_max_pos_of_last_change(self, entity): + """Returns an upper bound of the stream id of the last change to an + entity. + """ + return self._entity_to_key.get(entity, self._earliest_known_stream_pos) |