diff options
Diffstat (limited to 'synapse')
137 files changed, 2609 insertions, 1294 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index ef8853bd24..f31cb9a3cb 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.26.0" +__version__ = "0.28.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index ac0a3655a5..f17fda6315 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -204,8 +204,8 @@ class Auth(object): ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - "User-Agent", - default=[""] + b"User-Agent", + default=[b""] )[0] if user and access_token and ip_addr: self.store.insert_client_ip( @@ -672,7 +672,7 @@ def has_access_token(request): bool: False if no access_token was given, True otherwise. """ query_params = request.args.get("access_token") - auth_headers = request.requestHeaders.getRawHeaders("Authorization") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") return bool(query_params) or bool(auth_headers) @@ -692,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401): AuthError: If there isn't an access_token in the request. """ - auth_headers = request.requestHeaders.getRawHeaders("Authorization") - query_params = request.args.get("access_token") + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + query_params = request.args.get(b"access_token") if auth_headers: # Try the get the access_token from a "Authorization: Bearer" # header diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 489efb7f86..5baba43966 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -16,6 +16,9 @@ """Contains constants from the specification.""" +# the "depth" field on events is limited to 2**63 - 1 +MAX_DEPTH = 2**63 - 1 + class Membership(object): diff --git a/synapse/api/errors.py b/synapse/api/errors.py index aa15f73f36..a9ff5576f3 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -15,9 +15,11 @@ """Contains exceptions and error codes.""" -import json import logging +import simplejson as json +from six import iteritems + logger = logging.getLogger(__name__) @@ -296,7 +298,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs): A dict representing the error response JSON. """ err = {"error": msg, "errcode": code} - for key, value in kwargs.iteritems(): + for key, value in iteritems(kwargs): err[key] = value return err diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 83206348e5..db43219d24 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -17,7 +17,7 @@ from synapse.storage.presence import UserPresenceState from synapse.types import UserID, RoomID from twisted.internet import defer -import ujson as json +import simplejson as json import jsonschema from jsonschema import FormatChecker diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index c6fe4516d1..58f2c9d68c 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -32,11 +32,11 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor -from twisted.web.resource import Resource +from twisted.internet import reactor, defer +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.appservice") @@ -64,7 +64,7 @@ class AppserviceServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, @@ -112,9 +112,14 @@ class ASReplicationHandler(ReplicationClientHandler): if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() - preserve_fn( - self.appservice_handler.notify_interested_services - )(max_stream_id) + run_in_background(self._notify_app_services, max_stream_id) + + @defer.inlineCallbacks + def _notify_app_services(self, room_stream_id): + try: + yield self.appservice_handler.notify_interested_services(room_stream_id) + except Exception: + logger.exception("Error notifying application services of event") def start(config_options): diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 0a8ce9bc66..267d34c881 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string from twisted.internet import reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.client_reader") @@ -88,7 +88,7 @@ class ClientReaderServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 172e989b54..b915d12d53 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -52,7 +52,7 @@ from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string from twisted.internet import reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.event_creator") @@ -104,7 +104,7 @@ class EventCreatorServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 20d157911b..c1dc66dd17 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -41,7 +41,7 @@ from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string from twisted.internet import reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.federation_reader") @@ -77,7 +77,7 @@ class FederationReaderServer(HomeServer): FEDERATION_PREFIX: TransportLayerServer(self), }) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index f760826d27..a08af83a4c 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -38,11 +38,11 @@ from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.util.async import Linearizer from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string from twisted.internet import defer, reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.federation_sender") @@ -91,7 +91,7 @@ class FederationSenderServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, @@ -229,7 +229,7 @@ class FederationSenderHandler(object): # presence, typing, etc. if stream_name == "federation": send_queue.process_rows_for_federation(self.federation_sender, rows) - preserve_fn(self.update_token)(token) + run_in_background(self.update_token, token) # We also need to poke the federation sender when new events happen elif stream_name == "events": @@ -237,19 +237,22 @@ class FederationSenderHandler(object): @defer.inlineCallbacks def update_token(self, token): - self.federation_position = token - - # We linearize here to ensure we don't have races updating the token - with (yield self._fed_position_linearizer.queue(None)): - if self._last_ack < self.federation_position: - yield self.store.update_federation_out_pos( - "federation", self.federation_position - ) + try: + self.federation_position = token + + # We linearize here to ensure we don't have races updating the token + with (yield self._fed_position_linearizer.queue(None)): + if self._last_ack < self.federation_position: + yield self.store.update_federation_out_pos( + "federation", self.federation_position + ) - # We ACK this token over replication so that the master can drop - # its in memory queues - self.replication_client.send_federation_ack(self.federation_position) - self._last_ack = self.federation_position + # We ACK this token over replication so that the master can drop + # its in memory queues + self.replication_client.send_federation_ack(self.federation_position) + self._last_ack = self.federation_position + except Exception: + logger.exception("Error updating federation stream position") if __name__ == '__main__': diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 816c080d18..b349e3e3ce 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string from twisted.internet import defer, reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.frontend_proxy") @@ -90,7 +90,7 @@ class KeyUploadServlet(RestServlet): # They're actually trying to upload something, proxy to main synapse. # Pass through the auth headers, if any, in case the access token # is there. - auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) headers = { "Authorization": auth_headers, } @@ -142,7 +142,7 @@ class FrontendProxyServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e477c7ced6..a0e465d644 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -48,6 +48,7 @@ from synapse.server import HomeServer from synapse.storage import are_all_users_on_domain from synapse.storage.engines import IncorrectDatabaseSetup, create_engine from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database +from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole @@ -56,7 +57,7 @@ from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string from twisted.application import service from twisted.internet import defer, reactor -from twisted.web.resource import EncodingResourceWrapper, Resource +from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.server import GzipEncoderFactory from twisted.web.static import File @@ -126,7 +127,7 @@ class SynapseHomeServer(HomeServer): if WEB_CLIENT_PREFIX in resources: root_resource = RootRedirect(WEB_CLIENT_PREFIX) else: - root_resource = Resource() + root_resource = NoResource() root_resource = create_resource_tree(resources, root_resource) @@ -402,6 +403,10 @@ def run(hs): stats = {} + # Contains the list of processes we will be monitoring + # currently either 0 or 1 + stats_process = [] + @defer.inlineCallbacks def phone_stats_home(): logger.info("Gathering stats for reporting") @@ -425,8 +430,21 @@ def run(hs): stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() + r30_results = yield hs.get_datastore().count_r30_users() + for name, count in r30_results.iteritems(): + stats["r30_users_" + name] = count + daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = CACHE_SIZE_FACTOR + stats["event_cache_size"] = hs.config.event_cache_size + + if len(stats_process) > 0: + stats["memory_rss"] = 0 + stats["cpu_average"] = 0 + for process in stats_process: + stats["memory_rss"] += process.memory_info().rss + stats["cpu_average"] += int(process.cpu_percent(interval=None)) logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: @@ -437,10 +455,32 @@ def run(hs): except Exception as e: logger.warn("Error reporting stats: %s", e) + def performance_stats_init(): + try: + import psutil + process = psutil.Process() + # Ensure we can fetch both, and make the initial request for cpu_percent + # so the next request will use this as the initial point. + process.memory_info().rss + process.cpu_percent(interval=None) + logger.info("report_stats can use psutil") + stats_process.append(process) + except (ImportError, AttributeError): + logger.warn( + "report_stats enabled but psutil is not installed or incorrect version." + " Disabling reporting of memory/cpu stats." + " Ensuring psutil is available will help matrix.org track performance" + " changes across releases." + ) + if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000) + # We need to defer this init for the cases that we daemonize + # otherwise the process ID we get is that of the non-daemon process + clock.call_later(0, performance_stats_init) + # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes clock.call_later(5 * 60, phone_stats_home) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 84c5791b3b..fc8282bbc1 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -43,7 +43,7 @@ from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string from twisted.internet import reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.media_repository") @@ -84,7 +84,7 @@ class MediaRepositoryServer(HomeServer): ), }) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 98a4a7c62c..26930d1b3b 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -33,11 +33,11 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string from twisted.internet import defer, reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.pusher") @@ -94,7 +94,7 @@ class PusherServer(HomeServer): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(self) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, @@ -140,24 +140,27 @@ class PusherReplicationHandler(ReplicationClientHandler): def on_rdata(self, stream_name, token, rows): super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) - preserve_fn(self.poke_pushers)(stream_name, token, rows) + run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks def poke_pushers(self, stream_name, token, rows): - if stream_name == "pushers": - for row in rows: - if row.deleted: - yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) - else: - yield self.start_pusher(row.user_id, row.app_id, row.pushkey) - elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( - token, token, - ) - elif stream_name == "receipts": - yield self.pusher_pool.on_new_receipts( - token, token, set(row.room_id for row in rows) - ) + try: + if stream_name == "pushers": + for row in rows: + if row.deleted: + yield self.stop_pusher(row.user_id, row.app_id, row.pushkey) + else: + yield self.start_pusher(row.user_id, row.app_id, row.pushkey) + elif stream_name == "events": + yield self.pusher_pool.on_new_notifications( + token, token, + ) + elif stream_name == "receipts": + yield self.pusher_pool.on_new_receipts( + token, token, set(row.room_id for row in rows) + ) + except Exception: + logger.exception("Error poking pushers") def stop_pusher(self, user_id, app_id, pushkey): key = "%s:%s" % (app_id, pushkey) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index abe91dcfbd..7152b1deb4 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -51,12 +51,14 @@ from synapse.storage.engines import create_engine from synapse.storage.presence import UserPresenceState from synapse.storage.roommember import RoomMemberStore from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string from twisted.internet import defer, reactor -from twisted.web.resource import Resource +from twisted.web.resource import NoResource + +from six import iteritems logger = logging.getLogger("synapse.app.synchrotron") @@ -211,7 +213,7 @@ class SynchrotronPresence(object): def get_currently_syncing_users(self): return [ - user_id for user_id, count in self.user_to_num_current_syncs.iteritems() + user_id for user_id, count in iteritems(self.user_to_num_current_syncs) if count > 0 ] @@ -269,7 +271,7 @@ class SynchrotronServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, @@ -325,8 +327,7 @@ class SyncReplicationHandler(ReplicationClientHandler): def on_rdata(self, stream_name, token, rows): super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) - - preserve_fn(self.process_and_notify)(stream_name, token, rows) + run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): args = super(SyncReplicationHandler, self).get_streams_to_replicate() @@ -338,55 +339,58 @@ class SyncReplicationHandler(ReplicationClientHandler): @defer.inlineCallbacks def process_and_notify(self, stream_name, token, rows): - if stream_name == "events": - # We shouldn't get multiple rows per token for events stream, so - # we don't need to optimise this for multiple rows. - for row in rows: - event = yield self.store.get_event(row.event_id) - extra_users = () - if event.type == EventTypes.Member: - extra_users = (event.state_key,) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event( - event, token, max_token, extra_users + try: + if stream_name == "events": + # We shouldn't get multiple rows per token for events stream, so + # we don't need to optimise this for multiple rows. + for row in rows: + event = yield self.store.get_event(row.event_id) + extra_users = () + if event.type == EventTypes.Member: + extra_users = (event.state_key,) + max_token = self.store.get_room_max_stream_ordering() + self.notifier.on_new_room_event( + event, token, max_token, extra_users + ) + elif stream_name == "push_rules": + self.notifier.on_new_event( + "push_rules_key", token, users=[row.user_id for row in rows], ) - elif stream_name == "push_rules": - self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows], - ) - elif stream_name in ("account_data", "tag_account_data",): - self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows], - ) - elif stream_name == "receipts": - self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows], - ) - elif stream_name == "typing": - self.typing_handler.process_replication_rows(token, rows) - self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows], - ) - elif stream_name == "to_device": - entities = [row.entity for row in rows if row.entity.startswith("@")] - if entities: + elif stream_name in ("account_data", "tag_account_data",): self.notifier.on_new_event( - "to_device_key", token, users=entities, + "account_data_key", token, users=[row.user_id for row in rows], ) - elif stream_name == "device_lists": - all_room_ids = set() - for row in rows: - room_ids = yield self.store.get_rooms_for_user(row.user_id) - all_room_ids.update(room_ids) - self.notifier.on_new_event( - "device_list_key", token, rooms=all_room_ids, - ) - elif stream_name == "presence": - yield self.presence_handler.process_replication_rows(token, rows) - elif stream_name == "receipts": - self.notifier.on_new_event( - "groups_key", token, users=[row.user_id for row in rows], - ) + elif stream_name == "receipts": + self.notifier.on_new_event( + "receipt_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "typing": + self.typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows], + ) + elif stream_name == "to_device": + entities = [row.entity for row in rows if row.entity.startswith("@")] + if entities: + self.notifier.on_new_event( + "to_device_key", token, users=entities, + ) + elif stream_name == "device_lists": + all_room_ids = set() + for row in rows: + room_ids = yield self.store.get_rooms_for_user(row.user_id) + all_room_ids.update(room_ids) + self.notifier.on_new_event( + "device_list_key", token, rooms=all_room_ids, + ) + elif stream_name == "presence": + yield self.presence_handler.process_replication_rows(token, rows) + elif stream_name == "receipts": + self.notifier.on_new_event( + "groups_key", token, users=[row.user_id for row in rows], + ) + except Exception: + logger.exception("Error processing replication") def start(config_options): diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 0f0ddfa78a..712dfa870e 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -38,7 +38,7 @@ def pid_running(pid): try: os.kill(pid, 0) return True - except OSError, err: + except OSError as err: if err.errno == errno.EPERM: return True return False @@ -98,7 +98,7 @@ def stop(pidfile, app): try: os.kill(pid, signal.SIGTERM) write("stopped %s" % (app,), colour=GREEN) - except OSError, err: + except OSError as err: if err.errno == errno.ESRCH: write("%s not running" % (app,), colour=YELLOW) elif err.errno == errno.EPERM: @@ -252,6 +252,7 @@ def main(): for running_pid in running_pids: while pid_running(running_pid): time.sleep(0.2) + write("All processes exited; now restarting...") if action == "start" or action == "restart": if start_stop_synapse: diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 494ccb702c..5ba7e9b416 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -39,11 +39,11 @@ from synapse.storage.engines import create_engine from synapse.storage.user_directory import UserDirectoryStore from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, run_in_background from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string -from twisted.internet import reactor -from twisted.web.resource import Resource +from twisted.internet import reactor, defer +from twisted.web.resource import NoResource logger = logging.getLogger("synapse.app.user_dir") @@ -116,7 +116,7 @@ class UserDirectoryServer(HomeServer): "/_matrix/client/api/v1": resource, }) - root_resource = create_resource_tree(resources, Resource()) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( bind_addresses, @@ -164,7 +164,14 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): stream_name, token, rows ) if stream_name == "current_state_deltas": - preserve_fn(self.user_directory.notify_new_event)() + run_in_background(self._notify_directory) + + @defer.inlineCallbacks + def _notify_directory(self): + try: + yield self.user_directory.notify_new_event() + except Exception: + logger.exception("Error notifiying user directory of state update") def start(config_options): diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index d5a7a5ce2f..5fdb579723 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -21,6 +21,8 @@ from twisted.internet import defer import logging import re +from six import string_types + logger = logging.getLogger(__name__) @@ -146,7 +148,7 @@ class ApplicationService(object): ) regex = regex_obj.get("regex") - if isinstance(regex, basestring): + if isinstance(regex, string_types): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: raise ValueError( diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 40c433d7ae..00efff1464 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -18,7 +18,6 @@ from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event -from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.caches.response_cache import ResponseCache from synapse.types import ThirdPartyInstanceID @@ -73,7 +72,8 @@ class ApplicationServiceApi(SimpleHttpClient): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS) + self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta", + timeout_ms=HOUR_IN_MS) @defer.inlineCallbacks def query_user(self, service, user_id): @@ -193,12 +193,7 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(None) key = (service.id, protocol) - result = self.protocol_meta_cache.get(key) - if not result: - result = self.protocol_meta_cache.set( - key, preserve_fn(_get)() - ) - return make_deferred_yieldable(result) + return self.protocol_meta_cache.wrap(key, _get) @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 6da315473d..6eddbc0828 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -51,7 +51,7 @@ components. from twisted.internet import defer from synapse.appservice import ApplicationServiceState -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure import logging @@ -106,7 +106,7 @@ class _ServiceQueuer(object): def enqueue(self, service, event): # if this service isn't being sent something self.queued_events.setdefault(service.id, []).append(event) - preserve_fn(self._send_request)(service) + run_in_background(self._send_request, service) @defer.inlineCallbacks def _send_request(self, service): @@ -152,10 +152,10 @@ class _TransactionController(object): if sent: yield txn.complete(self.store) else: - preserve_fn(self._start_recoverer)(service) - except Exception as e: - logger.exception(e) - preserve_fn(self._start_recoverer)(service) + run_in_background(self._start_recoverer, service) + except Exception: + logger.exception("Error creating appservice transaction") + run_in_background(self._start_recoverer, service) @defer.inlineCallbacks def on_recovered(self, recoverer): @@ -176,17 +176,20 @@ class _TransactionController(object): @defer.inlineCallbacks def _start_recoverer(self, service): - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) - logger.info( - "Application service falling behind. Starting recoverer. AS ID %s", - service.id - ) - recoverer = self.recoverer_fn(service, self.on_recovered) - self.add_recoverers([recoverer]) - recoverer.recover() + try: + yield self.store.set_appservice_state( + service, + ApplicationServiceState.DOWN + ) + logger.info( + "Application service falling behind. Starting recoverer. AS ID %s", + service.id + ) + recoverer = self.recoverer_fn(service, self.on_recovered) + self.add_recoverers([recoverer]) + recoverer.recover() + except Exception: + logger.exception("Error starting AS recoverer") @defer.inlineCallbacks def _is_service_up(self, service): diff --git a/synapse/config/_base.py b/synapse/config/_base.py index fa105bce72..b748ed2b0a 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -19,6 +19,8 @@ import os import yaml from textwrap import dedent +from six import integer_types + class ConfigError(Exception): pass @@ -49,7 +51,7 @@ Missing mandatory `server_name` config option. class Config(object): @staticmethod def parse_size(value): - if isinstance(value, int) or isinstance(value, long): + if isinstance(value, integer_types): return value sizes = {"K": 1024, "M": 1024 * 1024} size = 1 @@ -61,7 +63,7 @@ class Config(object): @staticmethod def parse_duration(value): - if isinstance(value, int) or isinstance(value, long): + if isinstance(value, integer_types): return value second = 1000 minute = 60 * second @@ -279,31 +281,31 @@ class Config(object): ) if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) - with open(config_path, "wb") as config_file: - config_bytes, config = obj.generate_config( + with open(config_path, "w") as config_file: + config_str, config = obj.generate_config( config_dir_path=config_dir_path, server_name=server_name, report_stats=(config_args.report_stats == "yes"), is_generating_file=True ) obj.invoke_all("generate_files", config) - config_file.write(config_bytes) - print ( + config_file.write(config_str) + print(( "A config file has been generated in %r for server name" " %r with corresponding SSL keys and self-signed" " certificates. Please review this file and customise it" " to your needs." - ) % (config_path, server_name) - print ( + ) % (config_path, server_name)) + print( "If this server name is incorrect, you will need to" " regenerate the SSL certificates" ) return else: - print ( + print(( "Config file %r already exists. Generating any missing key" " files." - ) % (config_path,) + ) % (config_path,)) generate_keys = True parser = argparse.ArgumentParser( diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index aba0aec6e8..277305e184 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -17,10 +17,12 @@ from ._base import Config, ConfigError from synapse.appservice import ApplicationService from synapse.types import UserID -import urllib import yaml import logging +from six import string_types +from six.moves.urllib import parse as urlparse + logger = logging.getLogger(__name__) @@ -89,21 +91,21 @@ def _load_appservice(hostname, as_info, config_filename): "id", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: - if not isinstance(as_info.get(field), basestring): + if not isinstance(as_info.get(field), string_types): raise KeyError("Required string field: '%s' (%s)" % ( field, config_filename, )) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if (not isinstance(as_info.get("url"), basestring) and + if (not isinstance(as_info.get("url"), string_types) and as_info.get("url", "") is not None): raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) localpart = as_info["sender_localpart"] - if urllib.quote(localpart) != localpart: + if urlparse.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." ) @@ -128,7 +130,7 @@ def _load_appservice(hostname, as_info, config_filename): "Expected namespace entry in %s to be an object," " but got %s", ns, regex_obj ) - if not isinstance(regex_obj.get("regex"), basestring): + if not isinstance(regex_obj.get("regex"), string_types): raise ValueError( "Missing/bad type 'regex' key in %s", regex_obj ) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 3f70039acd..6a7228dc2f 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -117,7 +117,7 @@ class LoggingConfig(Config): log_config = config.get("log_config") if log_config and not os.path.exists(log_config): log_file = self.abspath("homeserver.log") - with open(log_config, "wb") as log_config_file: + with open(log_config, "w") as log_config_file: log_config_file.write( DEFAULT_LOG_CONFIG.substitute(log_file=log_file) ) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 336959094b..c5384b3ad4 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -77,7 +77,9 @@ class RegistrationConfig(Config): # Set the number of bcrypt rounds used to generate password hash. # Larger numbers increase the work factor needed to generate the hash. - # The default number of rounds is 12. + # The default number is 12 (which equates to 2^12 rounds). + # N.B. that increasing this will exponentially increase the time required + # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins. bcrypt_rounds: 12 # Allows users to register as guests without a password/email/etc, and diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 29eb012ddb..b66154bc7c 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -133,7 +133,7 @@ class TlsConfig(Config): tls_dh_params_path = config["tls_dh_params_path"] if not self.path_exists(tls_private_key_path): - with open(tls_private_key_path, "w") as private_key_file: + with open(tls_private_key_path, "wb") as private_key_file: tls_private_key = crypto.PKey() tls_private_key.generate_key(crypto.TYPE_RSA, 2048) private_key_pem = crypto.dump_privatekey( @@ -148,7 +148,7 @@ class TlsConfig(Config): ) if not self.path_exists(tls_certificate_path): - with open(tls_certificate_path, "w") as certificate_file: + with open(tls_certificate_path, "wb") as certificate_file: cert = crypto.X509() subject = cert.get_subject() subject.CN = config["server_name"] diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index cff3ca809a..0397f73ab4 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -13,8 +13,8 @@ # limitations under the License. from twisted.internet import ssl -from OpenSSL import SSL -from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName +from OpenSSL import SSL, crypto +from twisted.internet._sslverify import _defaultCurveName import logging @@ -32,8 +32,9 @@ class ServerContextFactory(ssl.ContextFactory): @staticmethod def configure_context(context, config): try: - _ecCurve = _OpenSSLECCurve(_defaultCurveName) - _ecCurve.addECKeyToContext(context) + _ecCurve = crypto.get_elliptic_curve(_defaultCurveName) + context.set_tmp_ecdh(_ecCurve) + except Exception: logger.exception("Failed to enable elliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 35f810b07b..22ee0fc93f 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -19,7 +19,8 @@ from synapse.api.errors import SynapseError, Codes from synapse.util import unwrapFirstError, logcontext from synapse.util.logcontext import ( PreserveLoggingContext, - preserve_fn + preserve_fn, + run_in_background, ) from synapse.util.metrics import Measure @@ -127,7 +128,7 @@ class Keyring(object): verify_requests.append(verify_request) - preserve_fn(self._start_key_lookups)(verify_requests) + run_in_background(self._start_key_lookups, verify_requests) # Pass those keys to handle_key_deferred so that the json object # signatures can be verified @@ -146,53 +147,56 @@ class Keyring(object): verify_requests (List[VerifyKeyRequest]): """ - # create a deferred for each server we're going to look up the keys - # for; we'll resolve them once we have completed our lookups. - # These will be passed into wait_for_previous_lookups to block - # any other lookups until we have finished. - # The deferreds are called with no logcontext. - server_to_deferred = { - rq.server_name: defer.Deferred() - for rq in verify_requests - } - - # We want to wait for any previous lookups to complete before - # proceeding. - yield self.wait_for_previous_lookups( - [rq.server_name for rq in verify_requests], - server_to_deferred, - ) - - # Actually start fetching keys. - self._get_server_verify_keys(verify_requests) - - # When we've finished fetching all the keys for a given server_name, - # resolve the deferred passed to `wait_for_previous_lookups` so that - # any lookups waiting will proceed. - # - # map from server name to a set of request ids - server_to_request_ids = {} - - for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - - def remove_deferreds(res, verify_request): - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids[server_name].discard(request_id) - if not server_to_request_ids[server_name]: - d = server_to_deferred.pop(server_name, None) - if d: - d.callback(None) - return res - - for verify_request in verify_requests: - verify_request.deferred.addBoth( - remove_deferreds, verify_request, + try: + # create a deferred for each server we're going to look up the keys + # for; we'll resolve them once we have completed our lookups. + # These will be passed into wait_for_previous_lookups to block + # any other lookups until we have finished. + # The deferreds are called with no logcontext. + server_to_deferred = { + rq.server_name: defer.Deferred() + for rq in verify_requests + } + + # We want to wait for any previous lookups to complete before + # proceeding. + yield self.wait_for_previous_lookups( + [rq.server_name for rq in verify_requests], + server_to_deferred, ) + # Actually start fetching keys. + self._get_server_verify_keys(verify_requests) + + # When we've finished fetching all the keys for a given server_name, + # resolve the deferred passed to `wait_for_previous_lookups` so that + # any lookups waiting will proceed. + # + # map from server name to a set of request ids + server_to_request_ids = {} + + for verify_request in verify_requests: + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids.setdefault(server_name, set()).add(request_id) + + def remove_deferreds(res, verify_request): + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids[server_name].discard(request_id) + if not server_to_request_ids[server_name]: + d = server_to_deferred.pop(server_name, None) + if d: + d.callback(None) + return res + + for verify_request in verify_requests: + verify_request.deferred.addBoth( + remove_deferreds, verify_request, + ) + except Exception: + logger.exception("Error starting key lookups") + @defer.inlineCallbacks def wait_for_previous_lookups(self, server_names, server_to_deferred): """Waits for any previous key lookups for the given servers to finish. @@ -313,7 +317,7 @@ class Keyring(object): if not verify_request.deferred.called: verify_request.deferred.errback(err) - preserve_fn(do_iterations)().addErrback(on_err) + run_in_background(do_iterations).addErrback(on_err) @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): @@ -329,8 +333,9 @@ class Keyring(object): """ res = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.get_server_verify_keys)( - server_name, key_ids + run_in_background( + self.store.get_server_verify_keys, + server_name, key_ids, ).addCallback(lambda ks, server: (server, ks), server_name) for server_name, key_ids in server_name_and_key_ids ], @@ -352,13 +357,13 @@ class Keyring(object): logger.exception( "Unable to get key from %r: %s %s", perspective_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) defer.returnValue({}) results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(get_key)(p_name, p_keys) + run_in_background(get_key, p_name, p_keys) for p_name, p_keys in self.perspective_servers.items() ], consumeErrors=True, @@ -384,7 +389,7 @@ class Keyring(object): logger.info( "Unable to get key %r for %r directly: %s %s", key_ids, server_name, - type(e).__name__, str(e.message), + type(e).__name__, str(e), ) if not keys: @@ -398,7 +403,7 @@ class Keyring(object): results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(get_key)(server_name, key_ids) + run_in_background(get_key, server_name, key_ids) for server_name, key_ids in server_name_and_key_ids ], consumeErrors=True, @@ -481,7 +486,8 @@ class Keyring(object): yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store_keys)( + run_in_background( + self.store_keys, server_name=server_name, from_server=perspective_name, verify_keys=response_keys, @@ -539,7 +545,8 @@ class Keyring(object): yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store_keys)( + run_in_background( + self.store_keys, server_name=key_server_name, from_server=server_name, verify_keys=verify_keys, @@ -615,7 +622,8 @@ class Keyring(object): yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.store_server_keys_json)( + run_in_background( + self.store.store_server_keys_json, server_name=server_name, key_id=key_id, from_server=server_name, @@ -716,7 +724,8 @@ class Keyring(object): # TODO(markjh): Store whether the keys have expired. return logcontext.make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self.store.store_server_verify_key)( + run_in_background( + self.store.store_server_verify_key, server_name, server_name, key.time_added, key ) for key_id, key in verify_keys.items() @@ -734,7 +743,7 @@ def _handle_key_deferred(verify_request): except IOError as e: logger.warn( "Got IOError when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), + server_name, type(e).__name__, str(e), ) raise SynapseError( 502, @@ -744,7 +753,7 @@ def _handle_key_deferred(verify_request): except Exception as e: logger.exception( "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), + server_name, type(e).__name__, str(e), ) raise SynapseError( 401, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index e673e96cc0..c3ff85c49a 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -47,14 +47,26 @@ class _EventInternalMetadata(object): def _event_dict_property(key): + # We want to be able to use hasattr with the event dict properties. + # However, (on python3) hasattr expects AttributeError to be raised. Hence, + # we need to transform the KeyError into an AttributeError def getter(self): - return self._event_dict[key] + try: + return self._event_dict[key] + except KeyError: + raise AttributeError(key) def setter(self, v): - self._event_dict[key] = v + try: + self._event_dict[key] = v + except KeyError: + raise AttributeError(key) def delete(self): - del self._event_dict[key] + try: + del self._event_dict[key] + except KeyError: + raise AttributeError(key) return property( getter, diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 79eaa31031..4cc98a3fe8 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -14,7 +14,10 @@ # limitations under the License. import logging -from synapse.api.errors import SynapseError +import six + +from synapse.api.constants import MAX_DEPTH +from synapse.api.errors import SynapseError, Codes from synapse.crypto.event_signing import check_event_content_hash from synapse.events import FrozenEvent from synapse.events.utils import prune_event @@ -190,11 +193,23 @@ def event_from_pdu_json(pdu_json, outlier=False): FrozenEvent Raises: - SynapseError: if the pdu is missing required fields + SynapseError: if the pdu is missing required fields or is otherwise + not a valid matrix event """ # we could probably enforce a bunch of other fields here (room_id, sender, # origin, etc etc) - assert_params_in_request(pdu_json, ('event_id', 'type')) + assert_params_in_request(pdu_json, ('event_id', 'type', 'depth')) + + depth = pdu_json['depth'] + if not isinstance(depth, six.integer_types): + raise SynapseError(400, "Depth %r not an intger" % (depth, ), + Codes.BAD_JSON) + + if depth < 0: + raise SynapseError(400, "Depth too small", Codes.BAD_JSON) + elif depth > MAX_DEPTH: + raise SynapseError(400, "Depth too large", Codes.BAD_JSON) + event = FrozenEvent( pdu_json ) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 38440da5b5..6163f7c466 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,6 +19,8 @@ import itertools import logging import random +from six.moves import range + from twisted.internet import defer from synapse.api.constants import Membership @@ -33,7 +35,7 @@ from synapse.federation.federation_base import ( import synapse.metrics from synapse.util import logcontext, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logutils import log_function from synapse.util.retryutils import NotRetryingDestination @@ -394,7 +396,7 @@ class FederationClient(FederationBase): seen_events = yield self.store.get_events(event_ids, allow_rejected=True) signed_events = seen_events.values() else: - seen_events = yield self.store.have_events(event_ids) + seen_events = yield self.store.have_seen_events(event_ids) signed_events = [] failed_to_fetch = set() @@ -413,11 +415,12 @@ class FederationClient(FederationBase): batch_size = 20 missing_events = list(missing_events) - for i in xrange(0, len(missing_events), batch_size): + for i in range(0, len(missing_events), batch_size): batch = set(missing_events[i:i + batch_size]) deferreds = [ - preserve_fn(self.get_pdu)( + run_in_background( + self.get_pdu, destinations=random_server_list(), event_id=e_id, ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index bea7fd0b71..247ddc89d5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,9 +31,10 @@ import synapse.metrics from synapse.types import get_domain_from_id from synapse.util import async from synapse.util.caches.response_cache import ResponseCache -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.logutils import log_function +from six import iteritems + # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. TRANSACTION_CONCURRENCY_LIMIT = 10 @@ -65,7 +67,7 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) + self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) @defer.inlineCallbacks @log_function @@ -212,16 +214,17 @@ class FederationServer(FederationBase): if not in_room: raise AuthError(403, "Host not in room.") - result = self._state_resp_cache.get((room_id, event_id)) - if not result: - with (yield self._server_linearizer.queue((origin, room_id))): - d = self._state_resp_cache.set( - (room_id, event_id), - preserve_fn(self._on_context_state_request_compute)(room_id, event_id) - ) - resp = yield make_deferred_yieldable(d) - else: - resp = yield make_deferred_yieldable(result) + # we grab the linearizer to protect ourselves from servers which hammer + # us. In theory we might already have the response to this query + # in the cache so we could return it without waiting for the linearizer + # - but that's non-trivial to get right, and anyway somewhat defeats + # the point of the linearizer. + with (yield self._server_linearizer.queue((origin, room_id))): + resp = yield self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, event_id, + ) defer.returnValue((200, resp)) @@ -425,9 +428,9 @@ class FederationServer(FederationBase): "Claimed one-time-keys: %s", ",".join(( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in json_result.iteritems() - for device_id, device_keys in user_keys.iteritems() - for key_id, _ in device_keys.iteritems() + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) )), ) @@ -494,13 +497,33 @@ class FederationServer(FederationBase): def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. + If the event is invalid, then this method throws a FederationError. + (The error will then be logged and sent back to the sender (which + probably won't do anything with it), and other events in the + transaction will be processed as normal). + + It is likely that we'll then receive other events which refer to + this rejected_event in their prev_events, etc. When that happens, + we'll attempt to fetch the rejected event again, which will presumably + fail, so those second-generation events will also get rejected. + + Eventually, we get to the point where there are more than 10 events + between any new events and the original rejected event. Since we + only try to backfill 10 events deep on received pdu, we then accept the + new event, possibly introducing a discontinuity in the DAG, with new + forward extremities, so normal service is approximately returned, + until we try to backfill across the discontinuity. + Args: origin (str): server which sent the pdu pdu (FrozenEvent): received pdu Returns (Deferred): completes with None - Raises: FederationError if the signatures / hash do not match - """ + + Raises: FederationError if the signatures / hash do not match, or + if the event was unacceptable for any other reason (eg, too large, + too many prev_events, couldn't find the prev_events) + """ # check that it's actually being sent from a valid destination to # workaround bug #1753 in 0.18.5 and 0.18.6 if origin != get_domain_from_id(pdu.event_id): diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 93e5acebc1..0f0c687b37 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -40,6 +40,8 @@ from collections import namedtuple import logging +from six import itervalues, iteritems + logger = logging.getLogger(__name__) @@ -122,7 +124,7 @@ class FederationRemoteSendQueue(object): user_ids = set( user_id - for uids in self.presence_changed.itervalues() + for uids in itervalues(self.presence_changed) for user_id in uids ) @@ -276,7 +278,7 @@ class FederationRemoteSendQueue(object): # stream position. keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]} - for ((destination, edu_key), pos) in keyed_edus.iteritems(): + for ((destination, edu_key), pos) in iteritems(keyed_edus): rows.append((pos, KeyedEduRow( key=edu_key, edu=self.keyed_edu[(destination, edu_key)], @@ -309,7 +311,7 @@ class FederationRemoteSendQueue(object): j = keys.bisect_right(to_token) + 1 device_messages = {self.device_messages[k]: k for k in keys[i:j]} - for (destination, pos) in device_messages.iteritems(): + for (destination, pos) in iteritems(device_messages): rows.append((pos, DeviceRow( destination=destination, ))) @@ -528,19 +530,19 @@ def process_rows_for_federation(transaction_queue, rows): if buff.presence: transaction_queue.send_presence(buff.presence) - for destination, edu_map in buff.keyed_edus.iteritems(): + for destination, edu_map in iteritems(buff.keyed_edus): for key, edu in edu_map.items(): transaction_queue.send_edu( edu.destination, edu.edu_type, edu.content, key=key, ) - for destination, edu_list in buff.edus.iteritems(): + for destination, edu_list in iteritems(buff.edus): for edu in edu_list: transaction_queue.send_edu( edu.destination, edu.edu_type, edu.content, key=None, ) - for destination, failure_list in buff.failures.iteritems(): + for destination, failure_list in iteritems(buff.failures): for failure in failure_list: transaction_queue.send_failure(destination, failure) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index a141ec9953..ded2b1871a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -169,7 +169,7 @@ class TransactionQueue(object): while True: last_token = yield self.store.get_federation_out_pos("events") next_token, events = yield self.store.get_all_new_events_stream( - last_token, self._last_poked_id, limit=20, + last_token, self._last_poked_id, limit=100, ) logger.debug("Handling %s -> %s", last_token, next_token) @@ -177,24 +177,33 @@ class TransactionQueue(object): if not events and next_token >= self._last_poked_id: break - for event in events: + @defer.inlineCallbacks + def handle_event(event): # Only send events for this server. send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of() is_mine = self.is_mine_id(event.event_id) if not is_mine and send_on_behalf_of is None: - continue - - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - destinations = yield self.state.get_current_hosts_in_room( - event.room_id, latest_event_ids=[ - prev_id for prev_id, _ in event.prev_events - ], - ) + return + + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = yield self.state.get_current_hosts_in_room( + event.room_id, latest_event_ids=[ + prev_id for prev_id, _ in event.prev_events + ], + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return + destinations = set(destinations) if send_on_behalf_of is not None: @@ -207,12 +216,44 @@ class TransactionQueue(object): self._send_pdu(event, destinations) - events_processed_counter.inc_by(len(events)) + @defer.inlineCallbacks + def handle_room_events(events): + for event in events: + yield handle_event(event) + + events_by_room = {} + for event in events: + events_by_room.setdefault(event.room_id, []).append(event) + + yield logcontext.make_deferred_yieldable(defer.gatherResults( + [ + logcontext.run_in_background(handle_room_events, evs) + for evs in events_by_room.itervalues() + ], + consumeErrors=True + )) yield self.store.update_federation_out_pos( "events", next_token ) + if events: + now = self.clock.time_msec() + ts = yield self.store.get_received_ts(events[-1].event_id) + + synapse.metrics.event_processing_lag.set( + now - ts, "federation_sender", + ) + synapse.metrics.event_processing_last_ts.set( + ts, "federation_sender", + ) + + events_processed_counter.inc_by(len(events)) + + synapse.metrics.event_processing_positions.set( + next_token, "federation_sender", + ) + finally: self._is_processing = False @@ -282,6 +323,8 @@ class TransactionQueue(object): break yield self._process_presence_inner(states_map.values()) + except Exception: + logger.exception("Error sending presence states to servers") finally: self._processing_pending_presence = False diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 5488e82985..6db8efa6dd 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +21,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.util.logutils import log_function import logging +import urllib logger = logging.getLogger(__name__) @@ -49,7 +51,7 @@ class TransportLayerClient(object): logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - path = PREFIX + "/state/%s/" % room_id + path = _create_path(PREFIX, "/state/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -71,7 +73,7 @@ class TransportLayerClient(object): logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) - path = PREFIX + "/state_ids/%s/" % room_id + path = _create_path(PREFIX, "/state_ids/%s/", room_id) return self.client.get_json( destination, path=path, args={"event_id": event_id}, ) @@ -93,7 +95,7 @@ class TransportLayerClient(object): logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) - path = PREFIX + "/event/%s/" % (event_id, ) + path = _create_path(PREFIX, "/event/%s/", event_id) return self.client.get_json(destination, path=path, timeout=timeout) @log_function @@ -119,7 +121,7 @@ class TransportLayerClient(object): # TODO: raise? return - path = PREFIX + "/backfill/%s/" % (room_id,) + path = _create_path(PREFIX, "/backfill/%s/", room_id) args = { "v": event_tuples, @@ -157,9 +159,11 @@ class TransportLayerClient(object): # generated by the json_data_callback. json_data = transaction.get_dict() + path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id) + response = yield self.client.put_json( transaction.destination, - path=PREFIX + "/send/%s/" % transaction.transaction_id, + path=path, data=json_data, json_data_callback=json_data_callback, long_retries=True, @@ -177,7 +181,7 @@ class TransportLayerClient(object): @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False): - path = PREFIX + "/query/%s" % query_type + path = _create_path(PREFIX, "/query/%s", query_type) content = yield self.client.get_json( destination=destination, @@ -222,7 +226,7 @@ class TransportLayerClient(object): "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id) + path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id) ignore_backoff = False retry_on_dns_fail = False @@ -248,7 +252,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_join(self, destination, room_id, event_id, content): - path = PREFIX + "/send_join/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -261,7 +265,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_leave(self, destination, room_id, event_id, content): - path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -280,7 +284,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_invite(self, destination, room_id, event_id, content): - path = PREFIX + "/invite/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( destination=destination, @@ -322,7 +326,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,) + path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,) response = yield self.client.put_json( destination=destination, @@ -335,7 +339,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id) content = yield self.client.get_json( destination=destination, @@ -347,7 +351,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def send_query_auth(self, destination, room_id, event_id, content): - path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) + path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( destination=destination, @@ -409,7 +413,7 @@ class TransportLayerClient(object): Returns: A dict containg the device keys. """ - path = PREFIX + "/user/devices/" + user_id + path = _create_path(PREFIX, "/user/devices/%s", user_id) content = yield self.client.get_json( destination=destination, @@ -459,7 +463,7 @@ class TransportLayerClient(object): @log_function def get_missing_events(self, destination, room_id, earliest_events, latest_events, limit, min_depth, timeout): - path = PREFIX + "/get_missing_events/%s" % (room_id,) + path = _create_path(PREFIX, "/get_missing_events/%s", room_id,) content = yield self.client.post_json( destination=destination, @@ -479,7 +483,7 @@ class TransportLayerClient(object): def get_group_profile(self, destination, group_id, requester_user_id): """Get a group profile """ - path = PREFIX + "/groups/%s/profile" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/profile", group_id,) return self.client.get_json( destination=destination, @@ -498,7 +502,7 @@ class TransportLayerClient(object): requester_user_id (str) content (dict): The new profile of the group """ - path = PREFIX + "/groups/%s/profile" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/profile", group_id,) return self.client.post_json( destination=destination, @@ -512,7 +516,7 @@ class TransportLayerClient(object): def get_group_summary(self, destination, group_id, requester_user_id): """Get a group summary """ - path = PREFIX + "/groups/%s/summary" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/summary", group_id,) return self.client.get_json( destination=destination, @@ -525,7 +529,7 @@ class TransportLayerClient(object): def get_rooms_in_group(self, destination, group_id, requester_user_id): """Get all rooms in a group """ - path = PREFIX + "/groups/%s/rooms" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/rooms", group_id,) return self.client.get_json( destination=destination, @@ -538,7 +542,7 @@ class TransportLayerClient(object): content): """Add a room to a group """ - path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) + path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) return self.client.post_json( destination=destination, @@ -552,7 +556,10 @@ class TransportLayerClient(object): config_key, content): """Update room in group """ - path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,) + path = _create_path( + PREFIX, "/groups/%s/room/%s/config/%s", + group_id, room_id, config_key, + ) return self.client.post_json( destination=destination, @@ -565,7 +572,7 @@ class TransportLayerClient(object): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): """Remove a room from a group """ - path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,) + path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,) return self.client.delete_json( destination=destination, @@ -578,7 +585,7 @@ class TransportLayerClient(object): def get_users_in_group(self, destination, group_id, requester_user_id): """Get users in a group """ - path = PREFIX + "/groups/%s/users" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/users", group_id,) return self.client.get_json( destination=destination, @@ -591,7 +598,7 @@ class TransportLayerClient(object): def get_invited_users_in_group(self, destination, group_id, requester_user_id): """Get users that have been invited to a group """ - path = PREFIX + "/groups/%s/invited_users" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,) return self.client.get_json( destination=destination, @@ -604,7 +611,23 @@ class TransportLayerClient(object): def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite """ - path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id) + path = _create_path( + PREFIX, "/groups/%s/users/%s/accept_invite", + group_id, user_id, + ) + + return self.client.post_json( + destination=destination, + path=path, + data=content, + ignore_backoff=True, + ) + + @log_function + def join_group(self, destination, group_id, user_id, content): + """Attempts to join a group + """ + path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( destination=destination, @@ -617,7 +640,7 @@ class TransportLayerClient(object): def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): """Invite a user to a group """ - path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id) + path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id) return self.client.post_json( destination=destination, @@ -633,7 +656,7 @@ class TransportLayerClient(object): invited. """ - path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id) + path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id) return self.client.post_json( destination=destination, @@ -647,7 +670,7 @@ class TransportLayerClient(object): user_id, content): """Remove a user fron a group """ - path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id) + path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id) return self.client.post_json( destination=destination, @@ -664,7 +687,7 @@ class TransportLayerClient(object): kicked from the group. """ - path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id) + path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id) return self.client.post_json( destination=destination, @@ -679,7 +702,7 @@ class TransportLayerClient(object): the attestations """ - path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id) + path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id) return self.client.post_json( destination=destination, @@ -694,11 +717,12 @@ class TransportLayerClient(object): """Update a room entry in a group summary """ if category_id: - path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % ( + path = _create_path( + PREFIX, "/groups/%s/summary/categories/%s/rooms/%s", group_id, category_id, room_id, ) else: - path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,) + path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) return self.client.post_json( destination=destination, @@ -714,11 +738,12 @@ class TransportLayerClient(object): """Delete a room entry in a group summary """ if category_id: - path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % ( + path = _create_path( + PREFIX + "/groups/%s/summary/categories/%s/rooms/%s", group_id, category_id, room_id, ) else: - path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,) + path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,) return self.client.delete_json( destination=destination, @@ -731,7 +756,7 @@ class TransportLayerClient(object): def get_group_categories(self, destination, group_id, requester_user_id): """Get all categories in a group """ - path = PREFIX + "/groups/%s/categories" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/categories", group_id,) return self.client.get_json( destination=destination, @@ -744,7 +769,7 @@ class TransportLayerClient(object): def get_group_category(self, destination, group_id, requester_user_id, category_id): """Get category info in a group """ - path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) return self.client.get_json( destination=destination, @@ -758,7 +783,7 @@ class TransportLayerClient(object): content): """Update a category in a group """ - path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) return self.client.post_json( destination=destination, @@ -773,7 +798,7 @@ class TransportLayerClient(object): category_id): """Delete a category in a group """ - path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,) + path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,) return self.client.delete_json( destination=destination, @@ -786,7 +811,7 @@ class TransportLayerClient(object): def get_group_roles(self, destination, group_id, requester_user_id): """Get all roles in a group """ - path = PREFIX + "/groups/%s/roles" % (group_id,) + path = _create_path(PREFIX, "/groups/%s/roles", group_id,) return self.client.get_json( destination=destination, @@ -799,7 +824,7 @@ class TransportLayerClient(object): def get_group_role(self, destination, group_id, requester_user_id, role_id): """Get a roles info """ - path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) return self.client.get_json( destination=destination, @@ -813,7 +838,7 @@ class TransportLayerClient(object): content): """Update a role in a group """ - path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) return self.client.post_json( destination=destination, @@ -827,7 +852,7 @@ class TransportLayerClient(object): def delete_group_role(self, destination, group_id, requester_user_id, role_id): """Delete a role in a group """ - path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,) + path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,) return self.client.delete_json( destination=destination, @@ -842,11 +867,12 @@ class TransportLayerClient(object): """Update a users entry in a group """ if role_id: - path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % ( + path = _create_path( + PREFIX, "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id, ) else: - path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,) + path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) return self.client.post_json( destination=destination, @@ -857,16 +883,32 @@ class TransportLayerClient(object): ) @log_function + def set_group_join_policy(self, destination, group_id, requester_user_id, + content): + """Sets the join policy for a group + """ + path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,) + + return self.client.put_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + data=content, + ignore_backoff=True, + ) + + @log_function def delete_group_summary_user(self, destination, group_id, requester_user_id, user_id, role_id): """Delete a users entry in a group """ if role_id: - path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % ( + path = _create_path( + PREFIX, "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id, ) else: - path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,) + path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,) return self.client.delete_json( destination=destination, @@ -889,3 +931,22 @@ class TransportLayerClient(object): data=content, ignore_backoff=True, ) + + +def _create_path(prefix, path, *args): + """Creates a path from the prefix, path template and args. Ensures that + all args are url encoded. + + Example: + + _create_path(PREFIX, "/event/%s/", event_id) + + Args: + prefix (str) + path (str): String template for the path + args: ([str]): Args to insert into path. Each arg will be url encoded + + Returns: + str + """ + return prefix + path % tuple(urllib.quote(arg, "") for arg in args) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index a66a6b0692..19d09f5422 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +25,7 @@ from synapse.http.servlet import ( ) from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.types import ThirdPartyInstanceID, get_domain_from_id import functools @@ -93,12 +94,6 @@ class Authenticator(object): "signatures": {}, } - if ( - self.federation_domain_whitelist is not None and - self.server_name not in self.federation_domain_whitelist - ): - raise FederationDeniedError(self.server_name) - if content is not None: json_request["content"] = content @@ -137,6 +132,12 @@ class Authenticator(object): json_request["origin"] = origin json_request["signatures"].setdefault(origin, {})[key] = sig + if ( + self.federation_domain_whitelist is not None and + origin not in self.federation_domain_whitelist + ): + raise FederationDeniedError(origin) + if not json_request["signatures"]: raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED, @@ -151,11 +152,18 @@ class Authenticator(object): # alive retry_timings = yield self.store.get_destination_retry_timings(origin) if retry_timings and retry_timings["retry_last_ts"]: - logger.info("Marking origin %r as up", origin) - preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0) + run_in_background(self._reset_retry_timings, origin) defer.returnValue(origin) + @defer.inlineCallbacks + def _reset_retry_timings(self, origin): + try: + logger.info("Marking origin %r as up", origin) + yield self.store.set_destination_retry_timings(origin, 0, 0) + except Exception: + logger.exception("Error resetting retry timings on %s", origin) + class BaseFederationServlet(object): REQUIRE_AUTH = True @@ -802,6 +810,23 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): defer.returnValue((200, new_content)) +class FederationGroupsJoinServlet(BaseFederationServlet): + """Attempt to join a group + """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$" + + @defer.inlineCallbacks + def on_POST(self, origin, content, query, group_id, user_id): + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = yield self.handler.join_group( + group_id, user_id, content, + ) + + defer.returnValue((200, new_content)) + + class FederationGroupsRemoveUserServlet(BaseFederationServlet): """Leave or kick a user from the group """ @@ -1124,6 +1149,24 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): defer.returnValue((200, resp)) +class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): + """Sets whether a group is joinable without an invite or knock + """ + PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$" + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + defer.returnValue((200, new_content)) + + FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationPullServlet, @@ -1163,6 +1206,7 @@ GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsInvitedUsersServlet, FederationGroupsInviteServlet, FederationGroupsAcceptInviteServlet, + FederationGroupsJoinServlet, FederationGroupsRemoveUserServlet, FederationGroupsSummaryRoomsServlet, FederationGroupsCategoriesServlet, @@ -1172,6 +1216,7 @@ GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsSummaryUsersServlet, FederationGroupsAddRoomsServlet, FederationGroupsAddRoomsConfigServlet, + FederationGroupsSettingJoinPolicyServlet, ) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 1fb709e6c3..6f11fa374b 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -42,7 +42,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.types import get_domain_from_id -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from signedjson.sign import sign_json @@ -165,31 +165,35 @@ class GroupAttestionRenewer(object): @defer.inlineCallbacks def _renew_attestation(group_id, user_id): - if not self.is_mine_id(group_id): - destination = get_domain_from_id(group_id) - elif not self.is_mine_id(user_id): - destination = get_domain_from_id(user_id) - else: - logger.warn( - "Incorrectly trying to do attestations for user: %r in %r", - user_id, group_id, + try: + if not self.is_mine_id(group_id): + destination = get_domain_from_id(group_id) + elif not self.is_mine_id(user_id): + destination = get_domain_from_id(user_id) + else: + logger.warn( + "Incorrectly trying to do attestations for user: %r in %r", + user_id, group_id, + ) + yield self.store.remove_attestation_renewal(group_id, user_id) + return + + attestation = self.attestations.create_attestation(group_id, user_id) + + yield self.transport_client.renew_group_attestation( + destination, group_id, user_id, + content={"attestation": attestation}, ) - yield self.store.remove_attestation_renewal(group_id, user_id) - return - - attestation = self.attestations.create_attestation(group_id, user_id) - yield self.transport_client.renew_group_attestation( - destination, group_id, user_id, - content={"attestation": attestation}, - ) - - yield self.store.update_attestation_renewal( - group_id, user_id, attestation - ) + yield self.store.update_attestation_renewal( + group_id, user_id, attestation + ) + except Exception: + logger.exception("Error renewing attestation of %r in %r", + user_id, group_id) for row in rows: group_id = row["group_id"] user_id = row["user_id"] - preserve_fn(_renew_attestation)(group_id, user_id) + run_in_background(_renew_attestation, group_id, user_id) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 0b995aed70..2d95b04e0c 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -206,6 +207,28 @@ class GroupsServerHandler(object): defer.returnValue({}) @defer.inlineCallbacks + def set_group_join_policy(self, group_id, requester_user_id, content): + """Sets the group join policy. + + Currently supported policies are: + - "invite": an invite must be received and accepted in order to join. + - "open": anyone can join. + """ + yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id + ) + + join_policy = _parse_join_policy_from_contents(content) + if join_policy is None: + raise SynapseError( + 400, "No value specified for 'm.join_policy'" + ) + + yield self.store.set_group_join_policy(group_id, join_policy=join_policy) + + defer.returnValue({}) + + @defer.inlineCallbacks def get_group_categories(self, group_id, requester_user_id): """Get all categories in a group (as seen by user) """ @@ -381,9 +404,16 @@ class GroupsServerHandler(object): yield self.check_group_is_ours(group_id, requester_user_id) - group_description = yield self.store.get_group(group_id) + group = yield self.store.get_group(group_id) + + if group: + cols = [ + "name", "short_description", "long_description", + "avatar_url", "is_public", + ] + group_description = {key: group[key] for key in cols} + group_description["is_openly_joinable"] = group["join_policy"] == "open" - if group_description: defer.returnValue(group_description) else: raise SynapseError(404, "Unknown group") @@ -655,30 +685,21 @@ class GroupsServerHandler(object): raise SynapseError(502, "Unknown state returned by HS") @defer.inlineCallbacks - def accept_invite(self, group_id, requester_user_id, content): - """User tries to accept an invite to the group. + def _add_user(self, group_id, user_id, content): + """Add a user to a group based on a content dict. - This is different from them asking to join, and so should error if no - invite exists (and they're not a member of the group) + See accept_invite, join_group. """ - - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_invited = yield self.store.is_user_invited_to_local_group( - group_id, requester_user_id, - ) - if not is_invited: - raise SynapseError(403, "User not invited to group") - - if not self.hs.is_mine_id(requester_user_id): + if not self.hs.is_mine_id(user_id): local_attestation = self.attestations.create_attestation( - group_id, requester_user_id, + group_id, user_id, ) + remote_attestation = content["attestation"] yield self.attestations.verify_attestation( remote_attestation, - user_id=requester_user_id, + user_id=user_id, group_id=group_id, ) else: @@ -688,13 +709,53 @@ class GroupsServerHandler(object): is_public = _parse_visibility_from_contents(content) yield self.store.add_user_to_group( - group_id, requester_user_id, + group_id, user_id, is_admin=False, is_public=is_public, local_attestation=local_attestation, remote_attestation=remote_attestation, ) + defer.returnValue(local_attestation) + + @defer.inlineCallbacks + def accept_invite(self, group_id, requester_user_id, content): + """User tries to accept an invite to the group. + + This is different from them asking to join, and so should error if no + invite exists (and they're not a member of the group) + """ + + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + + is_invited = yield self.store.is_user_invited_to_local_group( + group_id, requester_user_id, + ) + if not is_invited: + raise SynapseError(403, "User not invited to group") + + local_attestation = yield self._add_user(group_id, requester_user_id, content) + + defer.returnValue({ + "state": "join", + "attestation": local_attestation, + }) + + @defer.inlineCallbacks + def join_group(self, group_id, requester_user_id, content): + """User tries to join the group. + + This will error if the group requires an invite/knock to join + """ + + group_info = yield self.check_group_is_ours( + group_id, requester_user_id, and_exists=True + ) + if group_info['join_policy'] != "open": + raise SynapseError(403, "Group is not publicly joinable") + + local_attestation = yield self._add_user(group_id, requester_user_id, content) + defer.returnValue({ "state": "join", "attestation": local_attestation, @@ -835,6 +896,31 @@ class GroupsServerHandler(object): }) +def _parse_join_policy_from_contents(content): + """Given a content for a request, return the specified join policy or None + """ + + join_policy_dict = content.get("m.join_policy") + if join_policy_dict: + return _parse_join_policy_dict(join_policy_dict) + else: + return None + + +def _parse_join_policy_dict(join_policy_dict): + """Given a dict for the "m.join_policy" config return the join policy specified + """ + join_policy_type = join_policy_dict.get("type") + if not join_policy_type: + return "invite" + + if join_policy_type not in ("invite", "open"): + raise SynapseError( + 400, "Synapse only supports 'invite'/'open' join rule" + ) + return join_policy_type + + def _parse_visibility_from_contents(content): """Given a content for a request parse out whether the entity should be public or not diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3dd3fa2a27..b596f098fd 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,7 +18,9 @@ from twisted.internet import defer import synapse from synapse.api.constants import EventTypes from synapse.util.metrics import Measure -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn +from synapse.util.logcontext import ( + make_deferred_yieldable, run_in_background, +) import logging @@ -84,11 +86,16 @@ class ApplicationServicesHandler(object): if not events: break + events_by_room = {} for event in events: + events_by_room.setdefault(event.room_id, []).append(event) + + @defer.inlineCallbacks + def handle_event(event): # Gather interested services services = yield self._get_services_for_event(event) if len(services) == 0: - continue # no services need notifying + return # no services need notifying # Do we know this user exists? If not, poke the user # query API for all services which match that user regex. @@ -104,13 +111,35 @@ class ApplicationServicesHandler(object): # Fork off pushes to these services for service in services: - preserve_fn(self.scheduler.submit_event_for_as)( - service, event - ) + self.scheduler.submit_event_for_as(service, event) - events_processed_counter.inc_by(len(events)) + @defer.inlineCallbacks + def handle_room_events(events): + for event in events: + yield handle_event(event) + + yield make_deferred_yieldable(defer.gatherResults([ + run_in_background(handle_room_events, evs) + for evs in events_by_room.itervalues() + ], consumeErrors=True)) yield self.store.set_appservice_last_pos(upper_bound) + + now = self.clock.time_msec() + ts = yield self.store.get_received_ts(events[-1].event_id) + + synapse.metrics.event_processing_positions.set( + upper_bound, "appservice_sender", + ) + + events_processed_counter.inc_by(len(events)) + + synapse.metrics.event_processing_lag.set( + now - ts, "appservice_sender", + ) + synapse.metrics.event_processing_last_ts.set( + ts, "appservice_sender", + ) finally: self.is_processing = False @@ -167,7 +196,10 @@ class ApplicationServicesHandler(object): services = yield self._get_services_for_3pn(protocol) results = yield make_deferred_yieldable(defer.DeferredList([ - preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) + run_in_background( + self.appservice_api.query_3pe, + service, kind, protocol, fields, + ) for service in services ], consumeErrors=True)) @@ -228,11 +260,15 @@ class ApplicationServicesHandler(object): event based on the service regex. """ services = self.store.get_app_services() - interested_list = [ - s for s in services if ( - yield s.is_interested(event, self.store) - ) - ] + + # we can't use a list comprehension here. Since python 3, list + # comprehensions use a generator internally. This means you can't yield + # inside of a list comprehension anymore. + interested_list = [] + for s in services: + if (yield s.is_interested(event, self.store)): + interested_list.append(s) + defer.returnValue(interested_list) def _get_services_for_user(self, user_id): diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 40f3d24678..f7457a7082 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -155,7 +155,7 @@ class DeviceHandler(BaseHandler): try: yield self.store.delete_device(user_id, device_id) - except errors.StoreError, e: + except errors.StoreError as e: if e.code == 404: # no match pass @@ -204,7 +204,7 @@ class DeviceHandler(BaseHandler): try: yield self.store.delete_devices(user_id, device_ids) - except errors.StoreError, e: + except errors.StoreError as e: if e.code == 404: # no match pass @@ -243,7 +243,7 @@ class DeviceHandler(BaseHandler): new_display_name=content.get("display_name") ) yield self.notify_device_update(user_id, [device_id]) - except errors.StoreError, e: + except errors.StoreError as e: if e.code == 404: raise errors.NotFoundError() else: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 31b1ece13e..25aec624af 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ujson as json +import simplejson as json import logging from canonicaljson import encode_canonical_json @@ -23,7 +24,7 @@ from synapse.api.errors import ( SynapseError, CodeMessageException, FederationDeniedError, ) from synapse.types import get_domain_from_id, UserID -from synapse.util.logcontext import preserve_fn, make_deferred_yieldable +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -134,28 +135,13 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } - except FederationDeniedError as e: - failures[destination] = { - "status": 403, "message": "Federation Denied", - } except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) yield make_deferred_yieldable(defer.gatherResults([ - preserve_fn(do_remote_query)(destination) + run_in_background(do_remote_query, destination) for destination in remote_queries_not_in_cache - ])) + ], consumeErrors=True)) defer.returnValue({ "device_keys": results, "failures": failures, @@ -252,24 +238,13 @@ class E2eKeysHandler(object): for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: json_result[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) yield make_deferred_yieldable(defer.gatherResults([ - preserve_fn(claim_client_keys)(destination) + run_in_background(claim_client_keys, destination) for destination in remote_queries - ])) + ], consumeErrors=True)) logger.info( "Claimed one-time-keys: %s", @@ -362,6 +337,31 @@ class E2eKeysHandler(object): ) +def _exception_to_failure(e): + if isinstance(e, CodeMessageException): + return { + "status": e.code, "message": e.message, + } + + if isinstance(e, NotRetryingDestination): + return { + "status": 503, "message": "Not ready for retry", + } + + if isinstance(e, FederationDeniedError): + return { + "status": 403, "message": "Federation Denied", + } + + # include ConnectionRefused and other errors + # + # Note that some Exceptions (notably twisted's ResponseFailed etc) don't + # give a string for e.message, which simplejson then fails to serialize. + return { + "status": 503, "message": str(e.message), + } + + def _one_time_keys_match(old_key_json, new_key): old_key = json.loads(old_key_json) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 080aca3d71..f39233d846 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -15,8 +15,16 @@ # limitations under the License. """Contains handlers for federation events.""" + +import itertools +import logging +import sys + from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json +import six +from six.moves import http_client +from twisted.internet import defer from unpaddedbase64 import decode_base64 from ._base import BaseHandler @@ -43,10 +51,6 @@ from synapse.util.retryutils import NotRetryingDestination from synapse.util.distributor import user_joined_room -from twisted.internet import defer - -import itertools -import logging logger = logging.getLogger(__name__) @@ -115,6 +119,19 @@ class FederationHandler(BaseHandler): logger.debug("Already seen pdu %s", pdu.event_id) return + # do some initial sanity-checking of the event. In particular, make + # sure it doesn't have hundreds of prev_events or auth_events, which + # could cause a huge state resolution or cascade of event fetches. + try: + self._sanity_check_event(pdu) + except SynapseError as err: + raise FederationError( + "ERROR", + err.code, + err.msg, + affected=pdu.event_id, + ) + # If we are currently in the process of joining this room, then we # queue up events for later processing. if pdu.room_id in self.room_queues: @@ -149,10 +166,6 @@ class FederationHandler(BaseHandler): auth_chain = [] - have_seen = yield self.store.have_events( - [ev for ev, _ in pdu.prev_events] - ) - fetch_state = False # Get missing pdus if necessary. @@ -168,7 +181,7 @@ class FederationHandler(BaseHandler): ) prevs = {e_id for e_id, _ in pdu.prev_events} - seen = set(have_seen.keys()) + seen = yield self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -196,8 +209,7 @@ class FederationHandler(BaseHandler): # Update the set of things we've seen after trying to # fetch the missing stuff - have_seen = yield self.store.have_events(prevs) - seen = set(have_seen.iterkeys()) + seen = yield self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -248,8 +260,7 @@ class FederationHandler(BaseHandler): min_depth (int): Minimum depth of events to return. """ # We recalculate seen, since it may have changed. - have_seen = yield self.store.have_events(prevs) - seen = set(have_seen.keys()) + seen = yield self.store.have_seen_events(prevs) if not prevs - seen: return @@ -361,9 +372,7 @@ class FederationHandler(BaseHandler): if auth_chain: event_ids |= {e.event_id for e in auth_chain} - seen_ids = set( - (yield self.store.have_events(event_ids)).keys() - ) + seen_ids = yield self.store.have_seen_events(event_ids) if state and auth_chain is not None: # If we have any state or auth_chain given to us by the replication @@ -527,9 +536,16 @@ class FederationHandler(BaseHandler): def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` - This will attempt to get more events from the remote. This may return - be successfull and still return no events if the other side has no new - events to offer. + This will attempt to get more events from the remote. If the other side + has no new events to offer, this will return an empty list. + + As the events are received, we check their signatures, and also do some + sanity-checking on them. If any of the backfilled events are invalid, + this method throws a SynapseError. + + TODO: make this more useful to distinguish failures of the remote + server from invalid events (there is probably no point in trying to + re-fetch invalid events from every other HS in the room.) """ if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") @@ -541,6 +557,16 @@ class FederationHandler(BaseHandler): extremities=extremities, ) + # ideally we'd sanity check the events here for excess prev_events etc, + # but it's hard to reject events at this point without completely + # breaking backfill in the same way that it is currently broken by + # events whose signature we cannot verify (#3121). + # + # So for now we accept the events anyway. #3124 tracks this. + # + # for ev in events: + # self._sanity_check_event(ev) + # Don't bother processing events we already have. seen_events = yield self.store.have_events_in_timeline( set(e.event_id for e in events) @@ -613,7 +639,8 @@ class FederationHandler(BaseHandler): results = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - logcontext.preserve_fn(self.replication_layer.get_pdu)( + logcontext.run_in_background( + self.replication_layer.get_pdu, [dest], event_id, outlier=True, @@ -633,7 +660,7 @@ class FederationHandler(BaseHandler): failed_to_fetch = missing_auth - set(auth_events) - seen_events = yield self.store.have_events( + seen_events = yield self.store.have_seen_events( set(auth_events.keys()) | set(state_events.keys()) ) @@ -843,6 +870,38 @@ class FederationHandler(BaseHandler): defer.returnValue(False) + def _sanity_check_event(self, ev): + """ + Do some early sanity checks of a received event + + In particular, checks it doesn't have an excessive number of + prev_events or auth_events, which could cause a huge state resolution + or cascade of event fetches. + + Args: + ev (synapse.events.EventBase): event to be checked + + Returns: None + + Raises: + SynapseError if the event does not pass muster + """ + if len(ev.prev_events) > 20: + logger.warn("Rejecting event %s which has %i prev_events", + ev.event_id, len(ev.prev_events)) + raise SynapseError( + http_client.BAD_REQUEST, + "Too many prev_events", + ) + + if len(ev.auth_events) > 10: + logger.warn("Rejecting event %s which has %i auth_events", + ev.event_id, len(ev.auth_events)) + raise SynapseError( + http_client.BAD_REQUEST, + "Too many auth_events", + ) + @defer.inlineCallbacks def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -967,7 +1026,7 @@ class FederationHandler(BaseHandler): # lots of requests for missing prev_events which we do actually # have. Hence we fire off the deferred, but don't wait for it. - logcontext.preserve_fn(self._handle_queued_pdus)(room_queue) + logcontext.run_in_background(self._handle_queued_pdus, room_queue) defer.returnValue(True) @@ -1457,18 +1516,21 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) except: # noqa: E722, as we reraise the exception this is fine. - # Ensure that we actually remove the entries in the push actions - # staging area - logcontext.preserve_fn( - self.store.remove_push_actions_from_staging - )(event.event_id) - raise + tp, value, tb = sys.exc_info() + + logcontext.run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) + + six.reraise(tp, value, tb) if not backfilled: # this intentionally does not yield: we don't care about the result # and don't need to wait for it. - logcontext.preserve_fn(self.pusher_pool.on_new_notifications)( - event_stream_id, max_stream_id + logcontext.run_in_background( + self.pusher_pool.on_new_notifications, + event_stream_id, max_stream_id, ) defer.returnValue((context, event_stream_id, max_stream_id)) @@ -1482,7 +1544,8 @@ class FederationHandler(BaseHandler): """ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( [ - logcontext.preserve_fn(self._prep_event)( + logcontext.run_in_background( + self._prep_event, origin, ev_info["event"], state=ev_info.get("state"), @@ -1736,7 +1799,8 @@ class FederationHandler(BaseHandler): event_key = None if event_auth_events - current_state: - have_events = yield self.store.have_events( + # TODO: can we use store.have_seen_events here instead? + have_events = yield self.store.get_seen_events_with_rejections( event_auth_events - current_state ) else: @@ -1759,12 +1823,12 @@ class FederationHandler(BaseHandler): origin, event.room_id, event.event_id ) - seen_remotes = yield self.store.have_events( + seen_remotes = yield self.store.have_seen_events( [e.event_id for e in remote_auth_chain] ) for e in remote_auth_chain: - if e.event_id in seen_remotes.keys(): + if e.event_id in seen_remotes: continue if e.event_id == event.event_id: @@ -1791,7 +1855,7 @@ class FederationHandler(BaseHandler): except AuthError: pass - have_events = yield self.store.have_events( + have_events = yield self.store.get_seen_events_with_rejections( [e_id for e_id, _ in event.auth_events] ) seen_events = set(have_events.keys()) @@ -1810,7 +1874,8 @@ class FederationHandler(BaseHandler): different_events = yield logcontext.make_deferred_yieldable( defer.gatherResults([ - logcontext.preserve_fn(self.store.get_event)( + logcontext.run_in_background( + self.store.get_event, d, allow_none=True, allow_rejected=False, @@ -1876,13 +1941,13 @@ class FederationHandler(BaseHandler): local_auth_chain, ) - seen_remotes = yield self.store.have_events( + seen_remotes = yield self.store.have_seen_events( [e.event_id for e in result["auth_chain"]] ) # 3. Process any remote auth chain events we haven't seen. for ev in result["auth_chain"]: - if ev.event_id in seen_remotes.keys(): + if ev.event_id in seen_remotes: continue if ev.event_id == event.event_id: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index e4d0cc8b02..977993e7d4 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -90,6 +91,8 @@ class GroupsLocalHandler(object): get_group_role = _create_rerouter("get_group_role") get_group_roles = _create_rerouter("get_group_roles") + set_group_join_policy = _create_rerouter("set_group_join_policy") + @defer.inlineCallbacks def get_group_summary(self, group_id, requester_user_id): """Get the group summary for a group. @@ -226,7 +229,45 @@ class GroupsLocalHandler(object): def join_group(self, group_id, user_id, content): """Request to join a group """ - raise NotImplementedError() # TODO + if self.is_mine_id(group_id): + yield self.groups_server_handler.join_group( + group_id, user_id, content + ) + local_attestation = None + remote_attestation = None + else: + local_attestation = self.attestations.create_attestation(group_id, user_id) + content["attestation"] = local_attestation + + res = yield self.transport_client.join_group( + get_domain_from_id(group_id), group_id, user_id, content, + ) + + remote_attestation = res["attestation"] + + yield self.attestations.verify_attestation( + remote_attestation, + group_id=group_id, + user_id=user_id, + server_name=get_domain_from_id(group_id), + ) + + # TODO: Check that the group is public and we're being added publically + is_publicised = content.get("publicise", False) + + token = yield self.store.register_user_group_membership( + group_id, user_id, + membership="join", + is_admin=False, + local_attestation=local_attestation, + remote_attestation=remote_attestation, + is_publicised=is_publicised, + ) + self.notifier.on_new_event( + "groups_key", token, users=[user_id], + ) + + defer.returnValue({}) @defer.inlineCallbacks def accept_invite(self, group_id, user_id, content): diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9efcdff1d6..91a0898860 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -15,6 +15,11 @@ # limitations under the License. """Utilities for interacting with Identity Servers""" + +import logging + +import simplejson as json + from twisted.internet import defer from synapse.api.errors import ( @@ -24,9 +29,6 @@ from ._base import BaseHandler from synapse.util.async import run_on_reactor from synapse.api.errors import SynapseError, Codes -import json -import logging - logger = logging.getLogger(__name__) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index c5267b4b84..cd33a86599 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -27,7 +27,7 @@ from synapse.types import ( from synapse.util import unwrapFirstError from synapse.util.async import concurrently_execute from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -166,7 +166,8 @@ class InitialSyncHandler(BaseHandler): (messages, token), current_state = yield make_deferred_yieldable( defer.gatherResults( [ - preserve_fn(self.store.get_recent_events_for_room)( + run_in_background( + self.store.get_recent_events_for_room, event.room_id, limit=limit, end_token=room_end_token, @@ -391,9 +392,10 @@ class InitialSyncHandler(BaseHandler): presence, receipts, (messages, token) = yield defer.gatherResults( [ - preserve_fn(get_presence)(), - preserve_fn(get_receipts)(), - preserve_fn(self.store.get_recent_events_for_room)( + run_in_background(get_presence), + run_in_background(get_receipts), + run_in_background( + self.store.get_recent_events_for_room, room_id, limit=limit, end_token=now_token.room_key, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4f97c8db79..b793fc4df7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -13,10 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import simplejson +import sys + +from canonicaljson import encode_canonical_json +import six from twisted.internet import defer, reactor from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, MAX_DEPTH from synapse.api.errors import AuthError, Codes, SynapseError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event @@ -25,21 +31,15 @@ from synapse.types import ( UserID, RoomAlias, RoomStreamToken, ) from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter -from synapse.util.logcontext import preserve_fn, run_in_background +from synapse.util.logcontext import run_in_background from synapse.util.metrics import measure_func -from synapse.util.frozenutils import unfreeze +from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client from synapse.replication.http.send_event import send_event_to_master from ._base import BaseHandler -from canonicaljson import encode_canonical_json - -import logging -import random -import ujson - logger = logging.getLogger(__name__) @@ -433,7 +433,7 @@ class EventCreationHandler(object): @defer.inlineCallbacks def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_event_ids=None): + prev_events_and_hashes=None): """ Given a dict from a client, create a new event. @@ -447,47 +447,52 @@ class EventCreationHandler(object): event_dict (dict): An entire event token_id (str) txn_id (str) - prev_event_ids (list): The prev event ids to use when creating the event + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. Returns: Tuple of created event (FrozenEvent), Context """ builder = self.event_builder_factory.new(event_dict) - with (yield self.limiter.queue(builder.room_id)): - self.validator.validate_new(builder) - - if builder.type == EventTypes.Member: - membership = builder.content.get("membership", None) - target = UserID.from_string(builder.state_key) - - if membership in {Membership.JOIN, Membership.INVITE}: - # If event doesn't include a display name, add one. - profile = self.profile_handler - content = builder.content - - try: - if "displayname" not in content: - content["displayname"] = yield profile.get_displayname(target) - if "avatar_url" not in content: - content["avatar_url"] = yield profile.get_avatar_url(target) - except Exception as e: - logger.info( - "Failed to get profile information for %r: %s", - target, e - ) + self.validator.validate_new(builder) + + if builder.type == EventTypes.Member: + membership = builder.content.get("membership", None) + target = UserID.from_string(builder.state_key) + + if membership in {Membership.JOIN, Membership.INVITE}: + # If event doesn't include a display name, add one. + profile = self.profile_handler + content = builder.content + + try: + if "displayname" not in content: + content["displayname"] = yield profile.get_displayname(target) + if "avatar_url" not in content: + content["avatar_url"] = yield profile.get_avatar_url(target) + except Exception as e: + logger.info( + "Failed to get profile information for %r: %s", + target, e + ) - if token_id is not None: - builder.internal_metadata.token_id = token_id + if token_id is not None: + builder.internal_metadata.token_id = token_id - if txn_id is not None: - builder.internal_metadata.txn_id = txn_id + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id - event, context = yield self.create_new_client_event( - builder=builder, - requester=requester, - prev_event_ids=prev_event_ids, - ) + event, context = yield self.create_new_client_event( + builder=builder, + requester=requester, + prev_events_and_hashes=prev_events_and_hashes, + ) defer.returnValue((event, context)) @@ -557,64 +562,80 @@ class EventCreationHandler(object): See self.create_event and self.send_nonmember_event. """ - event, context = yield self.create_event( - requester, - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id - ) - spam_error = self.spam_checker.check_event_for_spam(event) - if spam_error: - if not isinstance(spam_error, basestring): - spam_error = "Spam is not permitted here" - raise SynapseError( - 403, spam_error, Codes.FORBIDDEN + # We limit the number of concurrent event sends in a room so that we + # don't fork the DAG too much. If we don't limit then we can end up in + # a situation where event persistence can't keep up, causing + # extremities to pile up, which in turn leads to state resolution + # taking longer. + with (yield self.limiter.queue(event_dict["room_id"])): + event, context = yield self.create_event( + requester, + event_dict, + token_id=requester.access_token_id, + txn_id=txn_id ) - yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=ratelimit, - ) + spam_error = self.spam_checker.check_event_for_spam(event) + if spam_error: + if not isinstance(spam_error, basestring): + spam_error = "Spam is not permitted here" + raise SynapseError( + 403, spam_error, Codes.FORBIDDEN + ) + + yield self.send_nonmember_event( + requester, + event, + context, + ratelimit=ratelimit, + ) defer.returnValue(event) @measure_func("create_new_client_event") @defer.inlineCallbacks - def create_new_client_event(self, builder, requester=None, prev_event_ids=None): - if prev_event_ids: - prev_events = yield self.store.add_event_hashes(prev_event_ids) - prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) - depth = prev_max_depth + 1 - else: - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) + def create_new_client_event(self, builder, requester=None, + prev_events_and_hashes=None): + """Create a new event for a local client - # We want to limit the max number of prev events we point to in our - # new event - if len(latest_ret) > 10: - # Sort by reverse depth, so we point to the most recent. - latest_ret.sort(key=lambda a: -a[2]) - new_latest_ret = latest_ret[:5] - - # We also randomly point to some of the older events, to make - # sure that we don't completely ignore the older events. - if latest_ret[5:]: - sample_size = min(5, len(latest_ret[5:])) - new_latest_ret.extend(random.sample(latest_ret[5:], sample_size)) - latest_ret = new_latest_ret - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 - else: - depth = 1 + Args: + builder (EventBuilder): + + requester (synapse.types.Requester|None): + + prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + the forward extremities to use as the prev_events for the + new event. For each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + + If None, they will be requested from the database. + + Returns: + Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] + """ + + if prev_events_and_hashes is not None: + assert len(prev_events_and_hashes) <= 10, \ + "Attempting to create an event with %i prev_events" % ( + len(prev_events_and_hashes), + ) + else: + prev_events_and_hashes = \ + yield self.store.get_prev_events_for_room(builder.room_id) + + if prev_events_and_hashes: + depth = max([d for _, _, d in prev_events_and_hashes]) + 1 + # we cap depth of generated events, to ensure that they are not + # rejected by other servers (and so that they can be persisted in + # the db) + depth = min(depth, MAX_DEPTH) + else: + depth = 1 - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in prev_events_and_hashes + ] builder.prev_events = prev_events builder.depth = depth @@ -678,8 +699,8 @@ class EventCreationHandler(object): # Ensure that we can round trip before trying to persist in db try: - dump = ujson.dumps(unfreeze(event.content)) - ujson.loads(dump) + dump = frozendict_json_encoder.encode(event.content) + simplejson.loads(dump) except Exception: logger.exception("Failed to encode content: %r", event.content) raise @@ -713,8 +734,14 @@ class EventCreationHandler(object): except: # noqa: E722, as we reraise the exception this is fine. # Ensure that we actually remove the entries in the push actions # staging area, if we calculated them. - preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id) - raise + tp, value, tb = sys.exc_info() + + run_in_background( + self.store.remove_push_actions_from_staging, + event.event_id, + ) + + six.reraise(tp, value, tb) @defer.inlineCallbacks def persist_and_notify_client_event( @@ -834,22 +861,33 @@ class EventCreationHandler(object): # this intentionally does not yield: we don't care about the result # and don't need to wait for it. - preserve_fn(self.pusher_pool.on_new_notifications)( + run_in_background( + self.pusher_pool.on_new_notifications, event_stream_id, max_stream_id ) @defer.inlineCallbacks def _notify(): yield run_on_reactor() - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + try: + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) + except Exception: + logger.exception("Error notifying about new room event") - preserve_fn(_notify)() + run_in_background(_notify) if event.type == EventTypes.Message: - presence = self.hs.get_presence_handler() # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. - preserve_fn(presence.bump_presence_active_time)(requester.user) + run_in_background(self._bump_active_time, requester.user) + + @defer.inlineCallbacks + def _bump_active_time(self, user): + try: + presence = self.hs.get_presence_handler() + yield presence.bump_presence_active_time(user) + except Exception: + logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index a5e501897c..585f3e4da2 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -31,7 +31,7 @@ from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.async import Linearizer -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -255,6 +255,14 @@ class PresenceHandler(object): logger.info("Finished _persist_unpersisted_changes") @defer.inlineCallbacks + def _update_states_and_catch_exception(self, new_states): + try: + res = yield self._update_states(new_states) + defer.returnValue(res) + except Exception: + logger.exception("Error updating presence") + + @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state @@ -364,7 +372,7 @@ class PresenceHandler(object): now=now, ) - preserve_fn(self._update_states)(changes) + run_in_background(self._update_states_and_catch_exception, changes) except Exception: logger.exception("Exception in _handle_timeouts loop") @@ -422,20 +430,23 @@ class PresenceHandler(object): @defer.inlineCallbacks def _end(): - if affect_presence: + try: self.user_to_num_current_syncs[user_id] -= 1 prev_state = yield self.current_state_for_user(user_id) yield self._update_states([prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec(), )]) + except Exception: + logger.exception("Error updating presence after sync") @contextmanager def _user_syncing(): try: yield finally: - preserve_fn(_end)() + if affect_presence: + run_in_background(_end) defer.returnValue(_user_syncing()) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 3f215c2b4e..2e0672161c 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -135,37 +135,40 @@ class ReceiptsHandler(BaseHandler): """Given a list of receipts, works out which remote servers should be poked and pokes them. """ - # TODO: Some of this stuff should be coallesced. - for receipt in receipts: - room_id = receipt["room_id"] - receipt_type = receipt["receipt_type"] - user_id = receipt["user_id"] - event_ids = receipt["event_ids"] - data = receipt["data"] - - users = yield self.state.get_current_user_in_room(room_id) - remotedomains = set(get_domain_from_id(u) for u in users) - remotedomains = remotedomains.copy() - remotedomains.discard(self.server_name) - - logger.debug("Sending receipt to: %r", remotedomains) - - for domain in remotedomains: - self.federation.send_edu( - destination=domain, - edu_type="m.receipt", - content={ - room_id: { - receipt_type: { - user_id: { - "event_ids": event_ids, - "data": data, + try: + # TODO: Some of this stuff should be coallesced. + for receipt in receipts: + room_id = receipt["room_id"] + receipt_type = receipt["receipt_type"] + user_id = receipt["user_id"] + event_ids = receipt["event_ids"] + data = receipt["data"] + + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) + remotedomains = remotedomains.copy() + remotedomains.discard(self.server_name) + + logger.debug("Sending receipt to: %r", remotedomains) + + for domain in remotedomains: + self.federation.send_edu( + destination=domain, + edu_type="m.receipt", + content={ + room_id: { + receipt_type: { + user_id: { + "event_ids": event_ids, + "data": data, + } } - } + }, }, - }, - key=(room_id, receipt_type, user_id), - ) + key=(room_id, receipt_type, user_id), + ) + except Exception: + logger.exception("Error pushing receipts to remote servers") @defer.inlineCallbacks def get_receipts_for_room(self, room_id, to_key): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ed5939880a..f83c6b3cf8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -23,8 +23,8 @@ from synapse.api.errors import ( ) from synapse.http.client import CaptchaServerHttpClient from synapse import types -from synapse.types import UserID -from synapse.util.async import run_on_reactor +from synapse.types import UserID, create_requester, RoomID, RoomAlias +from synapse.util.async import run_on_reactor, Linearizer from synapse.util.threepids import check_3pid_allowed from ._base import BaseHandler @@ -46,6 +46,10 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() + self._generate_user_id_linearizer = Linearizer( + name="_generate_user_id_linearizer", + ) + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): @@ -201,10 +205,17 @@ class RegistrationHandler(BaseHandler): token = None attempts += 1 + # auto-join the user to any rooms we're supposed to dump them into + fake_requester = create_requester(user_id) + for r in self.hs.config.auto_join_rooms: + try: + yield self._join_user_to_room(fake_requester, r) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding # rather than there being consistent matrix-wide ones, so we don't. - defer.returnValue((user_id, token)) @defer.inlineCallbacks @@ -345,9 +356,11 @@ class RegistrationHandler(BaseHandler): @defer.inlineCallbacks def _generate_user_id(self, reseed=False): if reseed or self._next_generated_user_id is None: - self._next_generated_user_id = ( - yield self.store.find_next_generated_user_id_localpart() - ) + with (yield self._generate_user_id_linearizer.queue(())): + if reseed or self._next_generated_user_id is None: + self._next_generated_user_id = ( + yield self.store.find_next_generated_user_id_localpart() + ) id = self._next_generated_user_id self._next_generated_user_id += 1 @@ -477,3 +490,28 @@ class RegistrationHandler(BaseHandler): ) defer.returnValue((user_id, access_token)) + + @defer.inlineCallbacks + def _join_user_to_room(self, requester, room_identifier): + room_id = None + room_member_handler = self.hs.get_room_member_handler() + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = ( + yield room_member_handler.lookup_room_alias(room_alias) + ) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) + + yield room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 5d81f59b44..5757bb7f8a 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -15,12 +15,13 @@ from twisted.internet import defer +from six.moves import range + from ._base import BaseHandler from synapse.api.constants import ( EventTypes, JoinRules, ) -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.async import concurrently_execute from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.response_cache import ResponseCache @@ -44,8 +45,9 @@ EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs): super(RoomListHandler, self).__init__(hs) - self.response_cache = ResponseCache(hs) - self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) + self.response_cache = ResponseCache(hs, "room_list") + self.remote_response_cache = ResponseCache(hs, "remote_room_list", + timeout_ms=30 * 1000) def get_local_public_room_list(self, limit=None, since_token=None, search_filter=None, @@ -77,18 +79,11 @@ class RoomListHandler(BaseHandler): ) key = (limit, since_token, network_tuple) - result = self.response_cache.get(key) - if not result: - logger.info("No cached result, calculating one.") - result = self.response_cache.set( - key, - preserve_fn(self._get_public_room_list)( - limit, since_token, network_tuple=network_tuple - ) - ) - else: - logger.info("Using cached deferred result.") - return make_deferred_yieldable(result) + return self.response_cache.wrap( + key, + self._get_public_room_list, + limit, since_token, network_tuple=network_tuple, + ) @defer.inlineCallbacks def _get_public_room_list(self, limit=None, since_token=None, @@ -207,7 +202,7 @@ class RoomListHandler(BaseHandler): step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1 chunk = [] - for i in xrange(0, len(rooms_to_scan), step): + for i in range(0, len(rooms_to_scan), step): batch = rooms_to_scan[i:i + step] logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( @@ -422,18 +417,14 @@ class RoomListHandler(BaseHandler): server_name, limit, since_token, include_all_networks, third_party_instance_id, ) - result = self.remote_response_cache.get(key) - if not result: - result = self.remote_response_cache.set( - key, - repl_layer.get_public_rooms( - server_name, limit=limit, since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - ) - return result + return self.remote_response_cache.wrap( + key, + repl_layer.get_public_rooms, + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) class RoomListNextBatch(namedtuple("RoomListNextBatch", ( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 9977be8831..714583f1d5 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -149,7 +149,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, - prev_event_ids, + prev_events_and_hashes, txn_id=None, ratelimit=True, content=None, @@ -175,7 +175,7 @@ class RoomMemberHandler(object): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_event_ids=prev_event_ids, + prev_events_and_hashes=prev_events_and_hashes, ) # Check if this event matches the previous membership event for the user. @@ -314,7 +314,12 @@ class RoomMemberHandler(object): 403, "Invites have been disabled on this server", ) - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + room_id, + ) + latest_event_ids = ( + event_id for (event_id, _, _) in prev_events_and_hashes + ) current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) @@ -403,7 +408,7 @@ class RoomMemberHandler(object): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, - prev_event_ids=latest_event_ids, + prev_events_and_hashes=prev_events_and_hashes, content=content, ) defer.returnValue(res) @@ -852,6 +857,14 @@ class RoomMemberMasterHandler(RoomMemberHandler): def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join """ + # filter ourselves out of remote_room_hosts: do_invite_join ignores it + # and if it is the only entry we'd like to return a 404 rather than a + # 500. + + remote_room_hosts = [ + host for host in remote_room_hosts if host != self.hs.hostname + ] + if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0f713ce038..b52e4c2aff 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util.async import concurrently_execute -from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn +from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure, measure_func from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user @@ -52,6 +52,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [ to tell if room needs to be part of the sync result. """ return bool(self.events) + __bool__ = __nonzero__ # python3 class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ @@ -76,6 +77,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) + __bool__ = __nonzero__ # python3 class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ @@ -95,6 +97,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ or self.state or self.account_data ) + __bool__ = __nonzero__ # python3 class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ @@ -106,6 +109,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ def __nonzero__(self): """Invited rooms should always be reported to the client""" return True + __bool__ = __nonzero__ # python3 class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ @@ -117,6 +121,7 @@ class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ def __nonzero__(self): return bool(self.join or self.invite or self.leave) + __bool__ = __nonzero__ # python3 class DeviceLists(collections.namedtuple("DeviceLists", [ @@ -127,6 +132,7 @@ class DeviceLists(collections.namedtuple("DeviceLists", [ def __nonzero__(self): return bool(self.changed or self.left) + __bool__ = __nonzero__ # python3 class SyncResult(collections.namedtuple("SyncResult", [ @@ -159,6 +165,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.device_lists or self.groups ) + __bool__ = __nonzero__ # python3 class SyncHandler(object): @@ -169,7 +176,7 @@ class SyncHandler(object): self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache(hs) + self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, @@ -180,15 +187,11 @@ class SyncHandler(object): Returns: A Deferred SyncResult. """ - result = self.response_cache.get(sync_config.request_key) - if not result: - result = self.response_cache.set( - sync_config.request_key, - preserve_fn(self._wait_for_sync_for_user)( - sync_config, since_token, timeout, full_state - ) - ) - return make_deferred_yieldable(result) + return self.response_cache.wrap( + sync_config.request_key, + self._wait_for_sync_for_user, + sync_config, since_token, timeout, full_state, + ) @defer.inlineCallbacks def _wait_for_sync_for_user(self, sync_config, since_token, timeout, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 77c0cf146f..5d9736e88f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID, get_domain_from_id @@ -97,7 +97,8 @@ class TypingHandler(object): if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - preserve_fn(self._push_remote)( + run_in_background( + self._push_remote, member=member, typing=True ) @@ -196,7 +197,7 @@ class TypingHandler(object): def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - preserve_fn(self._push_remote)(member, typing) + run_in_background(self._push_remote, member, typing) self._push_update_local( member=member, @@ -205,28 +206,31 @@ class TypingHandler(object): @defer.inlineCallbacks def _push_remote(self, member, typing): - users = yield self.state.get_current_user_in_room(member.room_id) - self._member_last_federation_poke[member] = self.clock.time_msec() + try: + users = yield self.state.get_current_user_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() - now = self.clock.time_msec() - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, - ) + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, + ) - for domain in set(get_domain_from_id(u) for u in users): - if domain != self.server_name: - self.federation.send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": member.room_id, - "user_id": member.user_id, - "typing": typing, - }, - key=member, - ) + for domain in set(get_domain_from_id(u) for u in users): + if domain != self.server_name: + self.federation.send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) + except Exception: + logger.exception("Error pushing typing notif to remotes") @defer.inlineCallbacks def _recv_edu(self, origin, content): diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index bfebb0f644..054372e179 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet.defer import CancelledError +from twisted.python import failure + +from synapse.api.errors import SynapseError + + +class RequestTimedOutError(SynapseError): + """Exception representing timeout of an outbound request""" + def __init__(self): + super(RequestTimedOutError, self).__init__(504, "Timed out") + + +def cancelled_to_request_timed_out_error(value, timeout): + """Turns CancelledErrors into RequestTimedOutErrors. + + For use with async.add_timeout_to_deferred + """ + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise RequestTimedOutError() + return value diff --git a/synapse/http/client.py b/synapse/http/client.py index f3e4973c2e..70a19d9b74 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +19,10 @@ from OpenSSL.SSL import VERIFY_NONE from synapse.api.errors import ( CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, ) +from synapse.http import cancelled_to_request_timed_out_error +from synapse.util.async import add_timeout_to_deferred from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.logcontext import make_deferred_yieldable -from synapse.util import logcontext import synapse.metrics from synapse.http.endpoint import SpiderEndpoint @@ -38,7 +40,7 @@ from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone -from StringIO import StringIO +from six import StringIO import simplejson as json import logging @@ -95,21 +97,17 @@ class SimpleHttpClient(object): # counters to it outgoing_requests_counter.inc(method) - def send_request(): + logger.info("Sending request %s %s", method, uri) + + try: request_deferred = self.agent.request( method, uri, *args, **kwargs ) - - return self.clock.time_bound_deferred( + add_timeout_to_deferred( request_deferred, - time_out=60, + 60, cancelled_to_request_timed_out_error, ) - - logger.info("Sending request %s %s", method, uri) - - try: - with logcontext.PreserveLoggingContext(): - response = yield send_request() + response = yield make_deferred_yieldable(request_deferred) incoming_responses_counter.inc(method, response.code) logger.info( @@ -509,7 +507,7 @@ class SpiderHttpClient(SimpleHttpClient): reactor, SpiderEndpointFactory(hs) ) - ), [('gzip', GzipDecoder)] + ), [(b'gzip', GzipDecoder)] ) # We could look like Chrome: # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 87639b9151..87a482650d 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import socket - from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer, reactor from twisted.internet.error import ConnectError @@ -33,7 +31,7 @@ SERVER_CACHE = {} # our record of an individual server which can be tried to reach a destination. # -# "host" is actually a dotted-quad or ipv6 address string. Except when there's +# "host" is the hostname acquired from the SRV record. Except when there's # no SRV record, in which case it is the original hostname. _Server = collections.namedtuple( "_Server", "priority weight host port expires" @@ -117,10 +115,15 @@ class _WrappedConnection(object): if time.time() - self.last_request >= 2.5 * 60: self.abort() # Abort the underlying TLS connection. The abort() method calls - # loseConnection() on the underlying TLS connection which tries to + # loseConnection() on the TLS connection which tries to # shutdown the connection cleanly. We call abortConnection() - # since that will promptly close the underlying TCP connection. - self.transport.abortConnection() + # since that will promptly close the TLS connection. + # + # In Twisted >18.4; the TLS connection will be None if it has closed + # which will make abortConnection() throw. Check that the TLS connection + # is not None before trying to close it. + if self.transport.getHandle() is not None: + self.transport.abortConnection() def request(self, request): self.last_request = time.time() @@ -288,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t if (len(answers) == 1 and answers[0].type == dns.SRV and answers[0].payload - and answers[0].payload.target == dns.Name('.')): + and answers[0].payload.target == dns.Name(b'.')): raise ConnectError("Service %s unavailable" % service_name) for answer in answers: @@ -297,20 +300,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t payload = answer.payload - hosts = yield _get_hosts_for_srv_record( - dns_client, str(payload.target) - ) - - for (ip, ttl) in hosts: - host_ttl = min(answer.ttl, ttl) - - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=str(payload.target), + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + answer.ttl, + )) servers.sort() cache[service_name] = list(servers) @@ -328,81 +324,3 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t raise e defer.returnValue(servers) - - -@defer.inlineCallbacks -def _get_hosts_for_srv_record(dns_client, host): - """Look up each of the hosts in a SRV record - - Args: - dns_client (twisted.names.dns.IResolver): - host (basestring): host to look up - - Returns: - Deferred[list[(str, int)]]: a list of (host, ttl) pairs - - """ - ip4_servers = [] - ip6_servers = [] - - def cb(res): - # lookupAddress and lookupIP6Address return a three-tuple - # giving the answer, authority, and additional sections of the - # response. - # - # we only care about the answers. - - return res[0] - - def eb(res, record_type): - if res.check(DNSNameError): - return [] - logger.warn("Error looking up %s for %s: %s", record_type, host, res) - return res - - # no logcontexts here, so we can safely fire these off and gatherResults - d1 = dns_client.lookupAddress(host).addCallbacks( - cb, eb, errbackArgs=("A", )) - d2 = dns_client.lookupIPV6Address(host).addCallbacks( - cb, eb, errbackArgs=("AAAA", )) - results = yield defer.DeferredList( - [d1, d2], consumeErrors=True) - - # if all of the lookups failed, raise an exception rather than blowing out - # the cache with an empty result. - if results and all(s == defer.FAILURE for (s, _) in results): - defer.returnValue(results[0][1]) - - for (success, result) in results: - if success == defer.FAILURE: - continue - - for answer in result: - if not answer.payload: - continue - - try: - if answer.type == dns.A: - ip = answer.payload.dottedQuad() - ip4_servers.append((ip, answer.ttl)) - elif answer.type == dns.AAAA: - ip = socket.inet_ntop( - socket.AF_INET6, answer.payload.address, - ) - ip6_servers.append((ip, answer.ttl)) - else: - # the most likely candidate here is a CNAME record. - # rfc2782 says srvs may not point to aliases. - logger.warn( - "Ignoring unexpected DNS record type %s for %s", - answer.type, host, - ) - continue - except Exception as e: - logger.warn("Ignoring invalid DNS response for %s: %s", - host, e) - continue - - # keep the ipv4 results before the ipv6 results, mostly to match historical - # behaviour. - defer.returnValue(ip4_servers + ip6_servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 9145405cb0..4b2b85464d 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.util.retryutils from twisted.internet import defer, reactor, protocol from twisted.internet.error import DNSLookupError from twisted.web.client import readBody, HTTPConnectionPool, Agent from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone +from synapse.http import cancelled_to_request_timed_out_error from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.async import sleep -from synapse.util import logcontext import synapse.metrics +from synapse.util.async import sleep, add_timeout_to_deferred +from synapse.util import logcontext +from synapse.util.logcontext import make_deferred_yieldable +import synapse.util.retryutils from canonicaljson import encode_canonical_json @@ -38,8 +41,7 @@ import logging import random import sys import urllib -import urlparse - +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) outbound_logger = logging.getLogger("synapse.http.outbound") @@ -184,21 +186,20 @@ class MatrixFederationHttpClient(object): producer = body_callback(method, http_url_bytes, headers_dict) try: - def send_request(): - request_deferred = self.agent.request( - method, - url_bytes, - Headers(headers_dict), - producer - ) - - return self.clock.time_bound_deferred( - request_deferred, - time_out=timeout / 1000. if timeout else 60, - ) - - with logcontext.PreserveLoggingContext(): - response = yield send_request() + request_deferred = self.agent.request( + method, + url_bytes, + Headers(headers_dict), + producer + ) + add_timeout_to_deferred( + request_deferred, + timeout / 1000. if timeout else 60, + cancelled_to_request_timed_out_error, + ) + response = yield make_deferred_yieldable( + request_deferred, + ) log_result = "%d %s" % (response.code, response.phrase,) break @@ -286,7 +287,8 @@ class MatrixFederationHttpClient(object): headers_dict[b"Authorization"] = auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, data={}, json_data_callback=None, + def put_json(self, destination, path, args={}, data={}, + json_data_callback=None, long_retries=False, timeout=None, ignore_backoff=False, backoff_on_404=False): @@ -296,6 +298,7 @@ class MatrixFederationHttpClient(object): destination (str): The remote server to send the HTTP request to. path (str): The HTTP path. + args (dict): query params data (dict): A dict containing the data that will be used as the request body. This will be encoded as JSON. json_data_callback (callable): A callable returning the dict to @@ -342,6 +345,7 @@ class MatrixFederationHttpClient(object): path, body_callback=body_callback, headers_dict={"Content-Type": ["application/json"]}, + query_bytes=encode_query_args(args), long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, @@ -373,6 +377,7 @@ class MatrixFederationHttpClient(object): giving up. None indicates no timeout. ignore_backoff (bool): true to ignore the historical backoff data and try the request anyway. + args (dict): query params Returns: Deferred: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. diff --git a/synapse/http/server.py b/synapse/http/server.py index 4b567215c8..55b9ad5251 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -37,7 +37,7 @@ from twisted.web.util import redirectTo import collections import logging import urllib -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -113,6 +113,11 @@ response_db_sched_duration = metrics.register_counter( "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] ) +# size in bytes of the response written +response_size = metrics.register_counter( + "response_size", labels=["method", "servlet", "tag"] +) + _next_request_id = 0 @@ -324,7 +329,7 @@ class JsonResource(HttpServer, resource.Resource): register_paths, so will return (possibly via Deferred) either None, or a tuple of (http code, response body). """ - if request.method == "OPTIONS": + if request.method == b"OPTIONS": return _options_handler, {} # Loop through all the registered callbacks to check if the method @@ -426,6 +431,8 @@ class RequestMetrics(object): context.db_sched_duration_ms / 1000., request.method, self.name, tag ) + response_size.inc_by(request.sentLength, request.method, self.name, tag) + class RootRedirect(resource.Resource): """Redirects the root '/' path to another path.""" @@ -461,8 +468,7 @@ def respond_with_json(request, code, json_object, send_cors=False, if canonical_json or synapse.events.USE_FROZEN_DICTS: json_bytes = encode_canonical_json(json_object) else: - # ujson doesn't like frozen_dicts. - json_bytes = ujson.dumps(json_object, ensure_ascii=False) + json_bytes = simplejson.dumps(json_object) return respond_with_json_bytes( request, code, json_bytes, @@ -489,6 +495,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Server", version_string) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) + request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") if send_cors: set_cors_headers(request) @@ -536,9 +543,9 @@ def finish_request(request): def _request_user_agent_is_curl(request): user_agents = request.requestHeaders.getRawHeaders( - "User-Agent", default=[] + b"User-Agent", default=[] ) for user_agent in user_agents: - if "curl" in user_agent: + if b"curl" in user_agent: return True return False diff --git a/synapse/http/site.py b/synapse/http/site.py index e422c8dfae..c8b46e1af2 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -20,7 +20,7 @@ import logging import re import time -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') class SynapseRequest(Request): @@ -43,12 +43,12 @@ class SynapseRequest(Request): def get_redacted_uri(self): return ACCESS_TOKEN_RE.sub( - r'\1<redacted>\3', + br'\1<redacted>\3', self.uri ) def get_user_agent(self): - return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1] + return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] def started_processing(self): self.site.access_logger.info( diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 50d99d7a5c..e3b831db67 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -17,12 +17,13 @@ import logging import functools import time import gc +import platform from twisted.internet import reactor from .metric import ( CounterMetric, CallbackMetric, DistributionMetric, CacheMetric, - MemoryUsageMetric, + MemoryUsageMetric, GaugeMetric, ) from .process_collector import register_process_collector @@ -30,6 +31,7 @@ from .process_collector import register_process_collector logger = logging.getLogger(__name__) +running_on_pypy = platform.python_implementation() == 'PyPy' all_metrics = [] all_collectors = [] @@ -63,6 +65,13 @@ class Metrics(object): """ return self._register(CounterMetric, *args, **kwargs) + def register_gauge(self, *args, **kwargs): + """ + Returns: + GaugeMetric + """ + return self._register(GaugeMetric, *args, **kwargs) + def register_callback(self, *args, **kwargs): """ Returns: @@ -142,6 +151,32 @@ reactor_metrics = get_metrics_for("python.twisted.reactor") tick_time = reactor_metrics.register_distribution("tick_time") pending_calls_metric = reactor_metrics.register_distribution("pending_calls") +synapse_metrics = get_metrics_for("synapse") + +# Used to track where various components have processed in the event stream, +# e.g. federation sending, appservice sending, etc. +event_processing_positions = synapse_metrics.register_gauge( + "event_processing_positions", labels=["name"], +) + +# Used to track the current max events stream position +event_persisted_position = synapse_metrics.register_gauge( + "event_persisted_position", +) + +# Used to track the received_ts of the last event processed by various +# components +event_processing_last_ts = synapse_metrics.register_gauge( + "event_processing_last_ts", labels=["name"], +) + +# Used to track the lag processing events. This is the time difference +# between the last processed event's received_ts and the time it was +# finished being processed. +event_processing_lag = synapse_metrics.register_gauge( + "event_processing_lag", labels=["name"], +) + def runUntilCurrentTimer(func): @@ -174,6 +209,9 @@ def runUntilCurrentTimer(func): tick_time.inc_by(end - start) pending_calls_metric.inc_by(num_pending) + if running_on_pypy: + return ret + # Check if we need to do a manual GC (since its been disabled), and do # one if necessary. threshold = gc.get_threshold() @@ -206,6 +244,7 @@ try: # We manually run the GC each reactor tick so that we can get some metrics # about time spent doing GC, - gc.disable() + if not running_on_pypy: + gc.disable() except AttributeError: pass diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py index ff5aa8c0e1..fbba94e633 100644 --- a/synapse/metrics/metric.py +++ b/synapse/metrics/metric.py @@ -16,6 +16,7 @@ from itertools import chain import logging +import re logger = logging.getLogger(__name__) @@ -56,8 +57,7 @@ class BaseMetric(object): return not len(self.labels) def _render_labelvalue(self, value): - # TODO: escape backslashes, quotes and newlines - return '"%s"' % (value) + return '"%s"' % (_escape_label_value(value),) def _render_key(self, values): if self.is_scalar(): @@ -115,7 +115,7 @@ class CounterMetric(BaseMetric): # dict[list[str]]: value for each set of label values. the keys are the # label values, in the same order as the labels in self.labels. # - # (if the metric is a scalar, the (single) key is the empty list). + # (if the metric is a scalar, the (single) key is the empty tuple). self.counts = {} # Scalar metrics are never empty @@ -145,6 +145,36 @@ class CounterMetric(BaseMetric): ) +class GaugeMetric(BaseMetric): + """A metric that can go up or down + """ + + def __init__(self, *args, **kwargs): + super(GaugeMetric, self).__init__(*args, **kwargs) + + # dict[list[str]]: value for each set of label values. the keys are the + # label values, in the same order as the labels in self.labels. + # + # (if the metric is a scalar, the (single) key is the empty tuple). + self.guages = {} + + def set(self, v, *values): + if len(values) != self.dimension(): + raise ValueError( + "Expected as many values to inc() as labels (%d)" % (self.dimension()) + ) + + # TODO: should assert that the tag values are all strings + + self.guages[values] = v + + def render(self): + return flatten( + self._render_for_labels(k, self.guages[k]) + for k in sorted(self.guages.keys()) + ) + + class CallbackMetric(BaseMetric): """A metric that returns the numeric value returned by a callback whenever it is rendered. Typically this is used to implement gauges that yield the @@ -269,3 +299,29 @@ class MemoryUsageMetric(object): "process_psutil_rss:total %d" % sum_rss, "process_psutil_rss:count %d" % len_rss, ] + + +def _escape_character(m): + """Replaces a single character with its escape sequence. + + Args: + m (re.MatchObject): A match object whose first group is the single + character to replace + + Returns: + str + """ + c = m.group(1) + if c == "\\": + return "\\\\" + elif c == "\"": + return "\\\"" + elif c == "\n": + return "\\n" + return c + + +def _escape_label_value(value): + """Takes a label value and escapes quotes, newlines and backslashes + """ + return re.sub(r"([\n\"\\])", _escape_character, value) diff --git a/synapse/notifier.py b/synapse/notifier.py index ef042681bc..8355c7d621 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -14,14 +14,17 @@ # limitations under the License. from twisted.internet import defer + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state -from synapse.util import DeferredTimedOutError from synapse.util.logutils import log_function -from synapse.util.async import ObservableDeferred -from synapse.util.logcontext import PreserveLoggingContext, preserve_fn +from synapse.util.async import ( + ObservableDeferred, add_timeout_to_deferred, + DeferredTimeoutError, +) +from synapse.util.logcontext import PreserveLoggingContext, run_in_background from synapse.util.metrics import Measure from synapse.types import StreamToken from synapse.visibility import filter_events_for_client @@ -144,6 +147,7 @@ class _NotifierUserStream(object): class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): def __nonzero__(self): return bool(self.events) + __bool__ = __nonzero__ # python3 class Notifier(object): @@ -250,9 +254,7 @@ class Notifier(object): def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - preserve_fn(self.appservice_handler.notify_interested_services)( - room_stream_id - ) + run_in_background(self._notify_app_services, room_stream_id) if self.federation_sender: self.federation_sender.notify_new_events(room_stream_id) @@ -266,6 +268,13 @@ class Notifier(object): rooms=[event.room_id], ) + @defer.inlineCallbacks + def _notify_app_services(self, room_stream_id): + try: + yield self.appservice_handler.notify_interested_services(room_stream_id) + except Exception: + logger.exception("Error notifying application services of event") + def on_new_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend event wise. @@ -330,11 +339,12 @@ class Notifier(object): # Now we wait for the _NotifierUserStream to be told there # is a new token. listener = user_stream.new_listener(prev_token) + add_timeout_to_deferred( + listener.deferred, + (end_time - now) / 1000., + ) with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. - ) + yield listener.deferred current_token = user_stream.current_token @@ -345,7 +355,7 @@ class Notifier(object): # Update the prev_token to the current_token since nothing # has happened between the old prev_token and the current_token prev_token = current_token - except DeferredTimedOutError: + except DeferredTimeoutError: break except defer.CancelledError: break @@ -550,13 +560,14 @@ class Notifier(object): if end_time <= now: break + add_timeout_to_deferred( + listener.deferred.addTimeout, + (end_time - now) / 1000., + ) try: with PreserveLoggingContext(): - yield self.clock.time_bound_deferred( - listener.deferred, - time_out=(end_time - now) / 1000. - ) - except DeferredTimedOutError: + yield listener.deferred + except DeferredTimeoutError: break except defer.CancelledError: break diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 58df98a793..ba7286cb72 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -77,10 +77,13 @@ class EmailPusher(object): @defer.inlineCallbacks def on_started(self): if self.mailer is not None: - self.throttle_params = yield self.store.get_throttle_params_by_room( - self.pusher_id - ) - yield self._process() + try: + self.throttle_params = yield self.store.get_throttle_params_by_room( + self.pusher_id + ) + yield self._process() + except Exception: + logger.exception("Error starting email pusher") def on_stop(self): if self.timed_call: diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 2cbac571b8..b077e1a446 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -18,8 +18,8 @@ import logging from twisted.internet import defer, reactor from twisted.internet.error import AlreadyCalled, AlreadyCancelled -import push_rule_evaluator -import push_tools +from . import push_rule_evaluator +from . import push_tools import synapse from synapse.push import PusherConfigException from synapse.util.logcontext import LoggingContext @@ -94,7 +94,10 @@ class HttpPusher(object): @defer.inlineCallbacks def on_started(self): - yield self._process() + try: + yield self._process() + except Exception: + logger.exception("Error starting http pusher") @defer.inlineCallbacks def on_new_notifications(self, min_stream_ordering, max_stream_ordering): diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 71576330a9..5aa6667e91 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from httppusher import HttpPusher +from .httppusher import HttpPusher import logging logger = logging.getLogger(__name__) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 134e89b371..750d11ca38 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from .pusher import PusherFactory -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn +from synapse.push.pusher import PusherFactory from synapse.util.async import run_on_reactor - -import logging +from synapse.util.logcontext import make_deferred_yieldable, run_in_background logger = logging.getLogger(__name__) @@ -137,12 +137,15 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - preserve_fn(p.on_new_notifications)( - min_stream_id, max_stream_id + run_in_background( + p.on_new_notifications, + min_stream_id, max_stream_id, ) ) - yield make_deferred_yieldable(defer.gatherResults(deferreds)) + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True), + ) except Exception: logger.exception("Exception in pusher on_new_notifications") @@ -164,10 +167,15 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): deferreds.append( - preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) + run_in_background( + p.on_new_receipts, + min_stream_id, max_stream_id, + ) ) - yield make_deferred_yieldable(defer.gatherResults(deferreds)) + yield make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True), + ) except Exception: logger.exception("Exception in pusher on_new_receipts") @@ -207,7 +215,7 @@ class PusherPool: if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() byuser[appid_pushkey] = p - preserve_fn(p.on_started)() + run_in_background(p.on_started) logger.info("Started pushers") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 91179ce532..216db4d164 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -1,5 +1,6 @@ # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,28 +19,43 @@ from distutils.version import LooseVersion logger = logging.getLogger(__name__) +# this dict maps from python package name to a list of modules we expect it to +# provide. +# +# the key is a "requirement specifier", as used as a parameter to `pip +# install`[1], or an `install_requires` argument to `setuptools.setup` [2]. +# +# the value is a sequence of strings; each entry should be the name of the +# python module, optionally followed by a version assertion which can be either +# ">=<ver>" or "==<ver>". +# +# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. +# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies REQUIREMENTS = { "jsonschema>=2.5.1": ["jsonschema>=2.5.1"], "frozendict>=0.4": ["frozendict"], "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"], - "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], + "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"], "signedjson>=1.0.0": ["signedjson>=1.0.0"], "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "Twisted>=16.0.0": ["twisted>=16.0.0"], - "pyopenssl>=0.14": ["OpenSSL>=0.14"], + + # We use crypto.get_elliptic_curve which is only supported in >=0.15 + "pyopenssl>=0.15": ["OpenSSL>=0.15"], + "pyyaml": ["yaml"], "pyasn1": ["pyasn1"], "daemonize": ["daemonize"], "bcrypt": ["bcrypt>=3.1.0"], "pillow": ["PIL"], "pydenticon": ["pydenticon"], - "ujson": ["ujson"], "blist": ["blist"], "pysaml2>=3.0.0": ["saml2>=3.0.0"], "pymacaroons-pynacl": ["pymacaroons"], "msgpack-python>=0.3.0": ["msgpack"], "phonenumbers>=8.2.0": ["phonenumbers"], + "six": ["six"], } CONDITIONAL_REQUIREMENTS = { "web_client": { diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index bbe2f967b7..a9baa2c1c3 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -23,7 +23,6 @@ from synapse.events.snapshot import EventContext from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.async import sleep from synapse.util.caches.response_cache import ResponseCache -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.metrics import Measure from synapse.types import Requester, UserID @@ -115,20 +114,15 @@ class ReplicationSendEventRestServlet(RestServlet): self.clock = hs.get_clock() # The responses are tiny, so we may as well cache them for a while - self.response_cache = ResponseCache(hs, timeout_ms=30 * 60 * 1000) + self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000) def on_PUT(self, request, event_id): - result = self.response_cache.get(event_id) - if not result: - result = self.response_cache.set( - event_id, - self._handle_request(request) - ) - else: - logger.warn("Returning cached response") - return make_deferred_yieldable(result) - - @preserve_fn + return self.response_cache.wrap( + event_id, + self._handle_request, + request + ) + @defer.inlineCallbacks def _handle_request(self, request): with Measure(self.clock, "repl_send_event_parse"): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 171227cce2..12aac3cc6b 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -19,11 +19,13 @@ allowed to be sent by which side. """ import logging -import ujson as json +import simplejson logger = logging.getLogger(__name__) +_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False) + class Command(object): """The base command class. @@ -100,14 +102,14 @@ class RdataCommand(Command): return cls( stream_name, None if token == "batch" else int(token), - json.loads(row_json) + simplejson.loads(row_json) ) def to_line(self): return " ".join(( self.stream_name, str(self.token) if self.token is not None else "batch", - json.dumps(self.row), + _json_encoder.encode(self.row), )) @@ -298,10 +300,12 @@ class InvalidateCacheCommand(Command): def from_line(cls, line): cache_func, keys_json = line.split(" ", 1) - return cls(cache_func, json.loads(keys_json)) + return cls(cache_func, simplejson.loads(keys_json)) def to_line(self): - return " ".join((self.cache_func, json.dumps(self.keys))) + return " ".join(( + self.cache_func, _json_encoder.encode(self.keys), + )) class UserIpCommand(Command): @@ -325,14 +329,14 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn) return cls( user_id, access_token, ip, user_agent, device_id, last_seen ) def to_line(self): - return self.user_id + " " + json.dumps(( + return self.user_id + " " + _json_encoder.encode(( self.access_token, self.ip, self.user_agent, self.device_id, self.last_seen, )) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0a9a290af4..d7d38464b2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -53,12 +53,12 @@ from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from commands import ( +from .commands import ( COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand, ) -from streams import STREAMS_MAP +from .streams import STREAMS_MAP from synapse.util.stringutils import random_string from synapse.metrics.metric import CounterMetric diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 786c3fe864..a41af4fd6c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -18,8 +18,8 @@ from twisted.internet import defer, reactor from twisted.internet.protocol import Factory -from streams import STREAMS_MAP, FederationStream -from protocol import ServerReplicationStreamProtocol +from .streams import STREAMS_MAP, FederationStream +from .protocol import ServerReplicationStreamProtocol from synapse.util.metrics import Measure, measure_func diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 303419d281..efd5c9873d 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -168,11 +168,24 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): yield self.store.find_first_stream_ordering_after_ts(ts) ) - (_, depth, _) = ( + room_event_after_stream_ordering = ( yield self.store.get_room_event_after_stream_ordering( room_id, stream_ordering, ) ) + if room_event_after_stream_ordering: + (_, depth, _) = room_event_after_stream_ordering + else: + logger.warn( + "[purge] purging events not possible: No event found " + "(received_ts %i => stream_ordering %i)", + ts, stream_ordering, + ) + raise SynapseError( + 404, + "there is no event to be purged", + errcode=Codes.NOT_FOUND, + ) logger.info( "[purge] purging up to depth %i (received_ts %i => " "stream_ordering %i)", diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index c7aa0bbf59..197335d7aa 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """A base Synapse REST Servlet for the client version 1 API. """ + # This subclass was presumably created to allow the auth for the v1 + # protocol version to be different, however this behaviour was removed. + # it may no longer be necessary + def __init__(self, hs): """ Args: @@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet): """ self.hs = hs self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs.get_clock()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 45844aa2d2..34df5be4e9 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -25,7 +25,7 @@ from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib -import urlparse +from six.moves.urllib import parse as urlparse import logging from saml2 import BINDING_HTTP_POST diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index ca49955935..e092158cb7 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -44,7 +44,10 @@ class LogoutRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) except AuthError: # this implies the access token has already been deleted. - pass + defer.returnValue((401, { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired" + })) else: if requester.device_id is None: # the acccess token wasn't associated with a device. diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 1819a560cb..0206e664c1 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet): super(RestServlet, self).__init__() self.hs = hs self.notifier = hs.get_notifier() - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() @defer.inlineCallbacks diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 5c5fa8f7ab..9b3022e0b0 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -30,6 +30,8 @@ from hashlib import sha1 import hmac import logging +from six import string_types + logger = logging.getLogger(__name__) @@ -333,11 +335,11 @@ class RegisterRestServlet(ClientV1RestServlet): def _do_shared_secret(self, request, register_json, session): yield run_on_reactor() - if not isinstance(register_json.get("mac", None), basestring): + if not isinstance(register_json.get("mac", None), string_types): raise SynapseError(400, "Expected mac.") - if not isinstance(register_json.get("user", None), basestring): + if not isinstance(register_json.get("user", None), string_types): raise SynapseError(400, "Expected 'user' key.") - if not isinstance(register_json.get("password", None), basestring): + if not isinstance(register_json.get("password", None), string_types): raise SynapseError(400, "Expected 'password' key.") if not self.hs.config.registration_shared_secret: @@ -348,9 +350,9 @@ class RegisterRestServlet(ClientV1RestServlet): admin = register_json.get("admin", None) # Its important to check as we use null bytes as HMAC field separators - if "\x00" in user: + if b"\x00" in user: raise SynapseError(400, "Invalid user") - if "\x00" in password: + if b"\x00" in password: raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not @@ -358,14 +360,14 @@ class RegisterRestServlet(ClientV1RestServlet): got_mac = str(register_json["mac"]) want_mac = hmac.new( - key=self.hs.config.registration_shared_secret, + key=self.hs.config.registration_shared_secret.encode(), digestmod=sha1, ) want_mac.update(user) - want_mac.update("\x00") + want_mac.update(b"\x00") want_mac.update(password) - want_mac.update("\x00") - want_mac.update("admin" if admin else "notadmin") + want_mac.update(b"\x00") + want_mac.update(b"admin" if admin else b"notadmin") want_mac = want_mac.hexdigest() if compare_digest(want_mac, got_mac): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f8999d64d7..fcf9c9ab44 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -28,9 +28,10 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, parse_integer ) +from six.moves.urllib import parse as urlparse + import logging -import urllib -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -165,17 +166,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): content=content, ) else: - event, context = yield self.event_creation_hander.create_event( + event = yield self.event_creation_hander.create_and_send_nonmember_event( requester, event_dict, - token_id=requester.access_token_id, txn_id=txn_id, ) - yield self.event_creation_hander.send_nonmember_event( - requester, event, context, - ) - ret = {} if event: ret = {"event_id": event.event_id} @@ -438,7 +434,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): as_client_event = "raw" not in request.args filter_bytes = request.args.get("filter", None) if filter_bytes: - filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") event_filter = Filter(json.loads(filter_json)) else: event_filter = None @@ -655,7 +651,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): content=event_content, ) - defer.returnValue((200, {})) + return_value = {} + + if membership_action == "join": + return_value["room_id"] = room_id + + defer.returnValue((200, return_value)) def _has_3pid_invite_keys(self, content): for key in {"id_server", "medium", "address"}: @@ -718,8 +719,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) - room_id = urllib.unquote(room_id) - target_user = UserID.from_string(urllib.unquote(user_id)) + room_id = urlparse.unquote(room_id) + target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index f762dbfa9a..3bb1ec2af6 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -401,6 +402,32 @@ class GroupInvitedUsersServlet(RestServlet): defer.returnValue((200, result)) +class GroupSettingJoinPolicyServlet(RestServlet): + """Set group join policy + """ + PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") + + def __init__(self, hs): + super(GroupSettingJoinPolicyServlet, self).__init__() + self.auth = hs.get_auth() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_PUT(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + + result = yield self.groups_handler.set_group_join_policy( + group_id, + requester_user_id, + content, + ) + + defer.returnValue((200, result)) + + class GroupCreateServlet(RestServlet): """Create a group """ @@ -738,6 +765,7 @@ def register_servlets(hs, http_server): GroupInvitedUsersServlet(hs).register(http_server) GroupUsersServlet(hs).register(http_server) GroupRoomServlet(hs).register(http_server) + GroupSettingJoinPolicyServlet(hs).register(http_server) GroupCreateServlet(hs).register(http_server) GroupAdminRoomsServlet(hs).register(http_server) GroupAdminRoomsConfigServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 0ba62bddc1..5cab00aea9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -20,7 +20,6 @@ import synapse import synapse.types from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType -from synapse.types import RoomID, RoomAlias from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string @@ -36,6 +35,8 @@ from hashlib import sha1 from synapse.util.async import run_on_reactor from synapse.util.ratelimitutils import FederationRateLimiter +from six import string_types + # We ought to be using hmac.compare_digest() but on older pythons it doesn't # exist. It's a _really minor_ security flaw to use plain string comparison @@ -211,14 +212,14 @@ class RegisterRestServlet(RestServlet): # in sessions. Pull out the username/password provided to us. desired_password = None if 'password' in body: - if (not isinstance(body['password'], basestring) or + if (not isinstance(body['password'], string_types) or len(body['password']) > 512): raise SynapseError(400, "Invalid password") desired_password = body["password"] desired_username = None if 'username' in body: - if (not isinstance(body['username'], basestring) or + if (not isinstance(body['username'], string_types) or len(body['username']) > 512): raise SynapseError(400, "Invalid username") desired_username = body['username'] @@ -244,7 +245,7 @@ class RegisterRestServlet(RestServlet): access_token = get_access_token_from_request(request) - if isinstance(desired_username, basestring): + if isinstance(desired_username, string_types): result = yield self._do_appservice_registration( desired_username, access_token, body ) @@ -405,14 +406,6 @@ class RegisterRestServlet(RestServlet): generate_token=False, ) - # auto-join the user to any rooms we're supposed to dump them into - fake_requester = synapse.types.create_requester(registered_user_id) - for r in self.hs.config.auto_join_rooms: - try: - yield self._join_user_to_room(fake_requester, r) - except Exception as e: - logger.error("Failed to join new user to %r: %r", r, e) - # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) self.auth_handler.set_session_data( @@ -446,29 +439,6 @@ class RegisterRestServlet(RestServlet): return 200, {} @defer.inlineCallbacks - def _join_user_to_room(self, requester, room_identifier): - room_id = None - if RoomID.is_valid(room_identifier): - room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = ( - yield self.room_member_handler.lookup_room_alias(room_alias) - ) - room_id = room_id.to_string() - else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) - - yield self.room_member_handler.update_membership( - requester=requester, - target=requester.user, - room_id=room_id, - action="join", - ) - - @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( username, as_token @@ -496,7 +466,7 @@ class RegisterRestServlet(RestServlet): # includes the password and admin flag in the hashed text. Why are # these different? want_mac = hmac.new( - key=self.hs.config.registration_shared_secret, + key=self.hs.config.registration_shared_secret.encode(), msg=user, digestmod=sha1, ).hexdigest() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a0a8e4b8e4..eb91c0b293 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -33,7 +33,7 @@ from ._base import set_timeline_upper_limit import itertools import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index e7ac01da01..c0d2f06855 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -28,7 +28,7 @@ import os import logging import urllib -import urlparse +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) @@ -143,6 +143,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam respond_404(request) return + logger.debug("Responding to media request with responder %s") add_file_headers(request, media_type, file_size, upload_name) with responder: yield responder.write_to_consumer(request) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index bb79599379..9800ce7581 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -47,7 +47,7 @@ import shutil import cgi import logging -import urlparse +from six.moves.urllib import parse as urlparse logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 83471b3173..d23fe10b07 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -16,6 +16,8 @@ from twisted.internet import defer, threads from twisted.protocols.basic import FileSender +import six + from ._base import Responder from synapse.util.file_consumer import BackgroundFileConsumer @@ -119,7 +121,7 @@ class MediaStorage(object): os.remove(fname) except Exception: pass - raise t, v, tb + six.reraise(t, v, tb) if not finished_called: raise Exception("Finished callback not called") @@ -253,7 +255,9 @@ class FileResponder(Responder): self.open_file = open_file def write_to_consumer(self, consumer): - return FileSender().beginFileTransfer(self.open_file, consumer) + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) def __exit__(self, exc_type, exc_val, exc_tb): self.open_file.close() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 31fe7aa75c..9290d7946f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -23,7 +23,7 @@ import re import shutil import sys import traceback -import ujson as json +import simplejson as json import urlparse from twisted.web.server import NOT_DONE_YET @@ -35,7 +35,7 @@ from ._base import FileInfo from synapse.api.errors import ( SynapseError, Codes, ) -from synapse.util.logcontext import preserve_fn, make_deferred_yieldable +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string from synapse.util.caches.expiringcache import ExpiringCache from synapse.http.client import SpiderHttpClient @@ -144,7 +144,8 @@ class PreviewUrlResource(Resource): observable = self._cache.get(url) if not observable: - download = preserve_fn(self._do_preview)( + download = run_in_background( + self._do_preview, url, requester.user, ts, ) observable = ObservableDeferred( diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index c188192f2b..0252afd9d3 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -18,7 +18,7 @@ from twisted.internet import defer, threads from .media_storage import FileResponder from synapse.config._base import Config -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import run_in_background import logging import os @@ -87,7 +87,12 @@ class StorageProviderWrapper(StorageProvider): return self.backend.store_file(path, file_info) else: # TODO: Handle errors. - preserve_fn(self.backend.store_file)(path, file_info) + def store(): + try: + return self.backend.store_file(path, file_info) + except Exception: + logger.exception("Error storing file") + run_in_background(store) return defer.succeed(None) def fetch(self, path, file_info): diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index f6f498cdc5..a31e75cb46 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -81,15 +81,15 @@ class UploadResource(Resource): headers = request.requestHeaders if headers.hasHeader("Content-Type"): - media_type = headers.getRawHeaders("Content-Type")[0] + media_type = headers.getRawHeaders(b"Content-Type")[0] else: raise SynapseError( msg="Upload request missing 'Content-Type'", code=400, ) - # if headers.hasHeader("Content-Disposition"): - # disposition = headers.getRawHeaders("Content-Disposition")[0] + # if headers.hasHeader(b"Content-Disposition"): + # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( diff --git a/synapse/server.py b/synapse/server.py index cd0c1a51be..ebdea6b0c4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -105,7 +105,6 @@ class HomeServer(object): 'federation_client', 'federation_server', 'handlers', - 'v1auth', 'auth', 'state_handler', 'state_resolution_handler', @@ -225,15 +224,6 @@ class HomeServer(object): def build_simple_http_client(self): return SimpleHttpClient(self) - def build_v1auth(self): - orf = Auth(self) - # Matrix spec makes no reference to what HTTP status code is returned, - # but the V1 API uses 403 where it means 401, and the webclient - # relies on this behaviour, so V1 gets its own copy of the auth - # with backwards compat behaviour. - orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403 - return orf - def build_state_handler(self): return StateHandler(self) diff --git a/synapse/state.py b/synapse/state.py index 932f602508..26093c8434 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -132,7 +132,7 @@ class StateHandler(object): state_map = yield self.store.get_events(state.values(), get_prev_content=False) state = { - key: state_map[e_id] for key, e_id in state.items() if e_id in state_map + key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map } defer.returnValue(state) @@ -378,7 +378,7 @@ class StateHandler(object): new_state = resolve_events_with_state_map(state_set_ids, state_map) new_state = { - key: state_map[ev_id] for key, ev_id in new_state.items() + key: state_map[ev_id] for key, ev_id in new_state.iteritems() } return new_state @@ -458,15 +458,15 @@ class StateResolutionHandler(object): # build a map from state key to the event_ids which set that state. # dict[(str, str), set[str]) state = {} - for st in state_groups_ids.values(): - for key, e_id in st.items(): + for st in state_groups_ids.itervalues(): + for key, e_id in st.iteritems(): state.setdefault(key, set()).add(e_id) # build a map from state key to the event_ids which set that state, # including only those where there are state keys in conflict. conflicted_state = { k: list(v) - for k, v in state.items() + for k, v in state.iteritems() if len(v) > 1 } @@ -480,36 +480,37 @@ class StateResolutionHandler(object): ) else: new_state = { - key: e_ids.pop() for key, e_ids in state.items() + key: e_ids.pop() for key, e_ids in state.iteritems() } - # if the new state matches any of the input state groups, we can - # use that state group again. Otherwise we will generate a state_id - # which will be used as a cache key for future resolutions, but - # not get persisted. - state_group = None - new_state_event_ids = frozenset(new_state.values()) - for sg, events in state_groups_ids.items(): - if new_state_event_ids == frozenset(e_id for e_id in events): - state_group = sg - break - - # TODO: We want to create a state group for this set of events, to - # increase cache hits, but we need to make sure that it doesn't - # end up as a prev_group without being added to the database - - prev_group = None - delta_ids = None - for old_group, old_ids in state_groups_ids.iteritems(): - if not set(new_state) - set(old_ids): - n_delta_ids = { - k: v - for k, v in new_state.iteritems() - if old_ids.get(k) != v - } - if not delta_ids or len(n_delta_ids) < len(delta_ids): - prev_group = old_group - delta_ids = n_delta_ids + with Measure(self.clock, "state.create_group_ids"): + # if the new state matches any of the input state groups, we can + # use that state group again. Otherwise we will generate a state_id + # which will be used as a cache key for future resolutions, but + # not get persisted. + state_group = None + new_state_event_ids = frozenset(new_state.itervalues()) + for sg, events in state_groups_ids.iteritems(): + if new_state_event_ids == frozenset(e_id for e_id in events): + state_group = sg + break + + # TODO: We want to create a state group for this set of events, to + # increase cache hits, but we need to make sure that it doesn't + # end up as a prev_group without being added to the database + + prev_group = None + delta_ids = None + for old_group, old_ids in state_groups_ids.iteritems(): + if not set(new_state) - set(old_ids): + n_delta_ids = { + k: v + for k, v in new_state.iteritems() + if old_ids.get(k) != v + } + if not delta_ids or len(n_delta_ids) < len(delta_ids): + prev_group = old_group + delta_ids = n_delta_ids cache = _StateCacheEntry( state=new_state, @@ -702,7 +703,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ auth_events = { key: state_map[ev_id] - for key, ev_id in auth_event_ids.items() + for key, ev_id in auth_event_ids.iteritems() if ev_id in state_map } @@ -740,7 +741,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in conflicted_state.iteritems(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) resolved_state[key] = _resolve_auth_events( @@ -750,7 +751,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in conflicted_state.iteritems(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) resolved_state[key] = _resolve_auth_events( @@ -760,7 +761,7 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in conflicted_state.items(): + for key, events in conflicted_state.iteritems(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) resolved_state[key] = _resolve_normal_events( diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index de00cae447..8cdfd50f90 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore @@ -244,13 +242,12 @@ class DataStore(RoomMemberStore, RoomStore, return [UserPresenceState(**row) for row in rows] - @defer.inlineCallbacks def count_daily_users(self): """ Counts the number of users who used this homeserver in the last 24 hours. """ def _count_users(txn): - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24), + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) sql = """ SELECT COALESCE(count(*), 0) FROM ( @@ -264,8 +261,91 @@ class DataStore(RoomMemberStore, RoomStore, count, = txn.fetchone() return count - ret = yield self.runInteraction("count_users", _count_users) - defer.returnValue(ret) + return self.runInteraction("count_users", _count_users) + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + for row in txn: + if row[0] is 'unknown': + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + count, = txn.fetchone() + results['all'] = count + + return results + + return self.runInteraction("count_r30_users", _count_r30_users) def get_users(self): """Function to reterive a list of users in users table. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 2fbebd4907..2262776ab2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -376,7 +376,7 @@ class SQLBaseStore(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(intern(column[0]) for column in cursor.description) + col_headers = list(intern(str(column[0])) for column in cursor.description) results = list( dict(zip(col_headers, row)) for row in cursor ) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index e70c9423e3..f83ff0454a 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks import abc -import ujson as json +import simplejson as json import logging logger = logging.getLogger(__name__) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index c88759bf2c..8af325a9f5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -19,7 +19,7 @@ from . import engines from twisted.internet import defer -import ujson as json +import simplejson as json import logging logger = logging.getLogger(__name__) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index a03d1d6104..7b44dae0fc 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -48,6 +48,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id", "last_seen"], ) + self.register_background_index_update( + "user_ips_last_seen_index", + index_name="user_ips_last_seen", + table="user_ips", + columns=["user_id", "last_seen"], + ) + # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) self._batch_row_update = {} diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 548e795daf..a879e5bfc1 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -import ujson +import simplejson from twisted.internet import defer @@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore): ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = ujson.dumps(edu) + edu_json = simplejson.dumps(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) @@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore): " WHERE user_id = ?" ) txn.execute(sql, (user_id,)) - message_json = ujson.dumps(messages_by_device["*"]) + message_json = simplejson.dumps(messages_by_device["*"]) for row in txn: # Add the message for all devices for this user on this # server. @@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore): # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = ujson.dumps(messages_by_device[device]) + message_json = simplejson.dumps(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: @@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(simplejson.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) @@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore): messages = [] for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(simplejson.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index bd2effdf34..712106b83a 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import ujson as json +import simplejson as json from twisted.internet import defer diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2cebb203c6..ff8538ddf8 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.util.caches.descriptors import cached from canonicaljson import encode_canonical_json -import ujson as json +import simplejson as json from ._base import SQLBaseStore diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 338b495611..8c868ece75 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -18,6 +18,7 @@ from .postgres import PostgresEngine from .sqlite3 import Sqlite3Engine import importlib +import platform SUPPORTED_MODULE = { @@ -31,6 +32,10 @@ def create_engine(database_config): engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: + # pypy requires psycopg2cffi rather than psycopg2 + if (name == "psycopg2" and + platform.python_implementation() == "PyPy"): + name = "psycopg2cffi" module = importlib.import_module(name) return engine_class(module, database_config) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 00ee82d300..8fbf7ffba7 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from twisted.internet import defer @@ -24,7 +25,9 @@ from synapse.util.caches.descriptors import cached from unpaddedbase64 import encode_base64 import logging -from Queue import PriorityQueue, Empty +from six.moves.queue import PriorityQueue, Empty + +from six.moves import range logger = logging.getLogger(__name__) @@ -78,7 +81,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, front_list = list(front) chunks = [ front_list[x:x + 100] - for x in xrange(0, len(front), 100) + for x in range(0, len(front), 100) ] for chunk in chunks: txn.execute( @@ -133,7 +136,47 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, retcol="event_id", ) + @defer.inlineCallbacks + def get_prev_events_for_room(self, room_id): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) + if len(res) > 10: + # Sort by reverse depth, so we point to the most recent. + res.sort(key=lambda a: -a[2]) + + # we use half of the limit for the actual most recent events, and + # the other half to randomly point to some of the older events, to + # make sure that we don't completely ignore the older events. + res = res[0:5] + random.sample(res[5:], 5) + + defer.returnValue(res) + def get_latest_event_ids_and_hashes_in_room(self, room_id): + """ + Gets the current forward extremities in the given room + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + return self.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, @@ -182,22 +225,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, room_id, ) - @defer.inlineCallbacks - def get_max_depth_of_events(self, event_ids): - sql = ( - "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" - ) % (",".join(["?"] * len(event_ids)),) - - rows = yield self._execute( - "get_max_depth_of_events", None, - sql, *event_ids - ) - - if rows: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(1) - def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 01f8339825..c22762eb5c 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -22,7 +22,7 @@ from synapse.types import RoomStreamToken from .stream import lower_bound import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -448,6 +448,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) + @defer.inlineCallbacks def remove_push_actions_from_staging(self, event_id): """Called if we failed to persist the event to ensure that stale push actions don't build up in the DB @@ -456,13 +457,22 @@ class EventPushActionsWorkerStore(SQLBaseStore): event_id (str) """ - return self._simple_delete( - table="event_push_actions_staging", - keyvalues={ - "event_id": event_id, - }, - desc="remove_push_actions_from_staging", - ) + try: + res = yield self._simple_delete( + table="event_push_actions_staging", + keyvalues={ + "event_id": event_id, + }, + desc="remove_push_actions_from_staging", + ) + defer.returnValue(res) + except Exception: + # this method is called from an exception handler, so propagating + # another exception here really isn't helpful - there's nothing + # the caller can do about it. Just log the exception and move on. + logger.exception( + "Error removing push actions after event persistence failure", + ) @defer.inlineCallbacks def _find_stream_orderings_for_times(self): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 3890878170..05cde96afc 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -14,15 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.events_worker import EventsWorkerStore +from collections import OrderedDict, deque, namedtuple +from functools import wraps +import itertools +import logging +import simplejson as json from twisted.internet import defer -from synapse.events import USE_FROZEN_DICTS - +from synapse.storage.events_worker import EventsWorkerStore from synapse.util.async import ObservableDeferred +from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import ( - PreserveLoggingContext, make_deferred_yieldable + PreserveLoggingContext, make_deferred_yieldable, ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -30,16 +34,8 @@ from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.types import get_domain_from_id - -from canonicaljson import encode_canonical_json -from collections import deque, namedtuple, OrderedDict -from functools import wraps - import synapse.metrics -import logging -import ujson as json - # these are only included to make the type annotations work from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 @@ -53,13 +49,25 @@ event_counter = metrics.register_counter( "persisted_events_sep", labels=["type", "origin_type", "origin_entity"] ) +# The number of times we are recalculating the current state +state_delta_counter = metrics.register_counter( + "state_delta", +) +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = metrics.register_counter( + "state_delta_single_event", +) +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = metrics.register_counter( + "state_delta_reuse_delta", +) + def encode_json(json_object): - if USE_FROZEN_DICTS: - # ujson doesn't like frozen_dicts - return encode_canonical_json(json_object) - else: - return json.dumps(json_object, ensure_ascii=False) + return frozendict_json_encoder.encode(json_object) class _EventPeristenceQueue(object): @@ -369,7 +377,8 @@ class EventsStore(EventsWorkerStore): room_id, ev_ctx_rm, latest_event_ids ) - if new_latest_event_ids == set(latest_event_ids): + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue @@ -390,12 +399,34 @@ class EventsStore(EventsWorkerStore): if all_single_prev_not_state: continue + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(e for e, _ in ev.prev_events) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + logger.info( "Calculating state delta for room %s", room_id, ) current_state = yield self._get_new_state_after_events( room_id, - ev_ctx_rm, new_latest_event_ids, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, ) if current_state is not None: current_state_for_room[room_id] = current_state @@ -415,6 +446,9 @@ class EventsStore(EventsWorkerStore): new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) + synapse.metrics.event_persisted_position.set( + chunk[-1][0].internal_metadata.stream_ordering, + ) for event, context in chunk: if context.app_service: origin_type = "local" @@ -480,7 +514,8 @@ class EventsStore(EventsWorkerStore): defer.returnValue(new_latest_event_ids) @defer.inlineCallbacks - def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids): + def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, + new_latest_event_ids): """Calculate the current state dict after adding some new events to a room @@ -491,6 +526,9 @@ class EventsStore(EventsWorkerStore): events_context (list[(EventBase, EventContext)]): events and contexts which are being added to the room + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + new_latest_event_ids (iterable[str]): the new forward extremities for the room. @@ -501,64 +539,89 @@ class EventsStore(EventsWorkerStore): """ if not new_latest_event_ids: - defer.returnValue({}) + return # map from state_group to ((type, key) -> event_id) state map - state_groups = {} - missing_event_ids = [] - was_updated = False + state_groups_map = {} + for ev, ctx in events_context: + if ctx.state_group is None: + # I don't think this can happen, but let's double-check + raise Exception( + "Context for new extremity event %s has no state " + "group" % (ev.event_id, ), + ) + + if ctx.state_group in state_groups_map: + continue + + state_groups_map[ctx.state_group] = ctx.current_state_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} for event_id in new_latest_event_ids: - # First search in the list of new events we're adding, - # and then use the current state from that + # First search in the list of new events we're adding. for ev, ctx in events_context: if event_id == ev.event_id: - if ctx.current_state_ids is None: - raise Exception("Unknown current state") - - if ctx.state_group is None: - # I don't think this can happen, but let's double-check - raise Exception( - "Context for new extremity event %s has no state " - "group" % (event_id, ), - ) - - # If we've already seen the state group don't bother adding - # it to the state sets again - if ctx.state_group not in state_groups: - state_groups[ctx.state_group] = ctx.current_state_ids - if ctx.delta_ids or hasattr(ev, "state_key"): - was_updated = True + event_id_to_state_group[event_id] = ctx.state_group break else: # If we couldn't find it, then we'll need to pull # the state from the database - was_updated = True - missing_event_ids.append(event_id) - - if not was_updated: - return + missing_event_ids.add(event_id) if missing_event_ids: - # Now pull out the state for any missing events from DB + # Now pull out the state groups for any missing events from DB event_to_groups = yield self._get_state_group_for_events( missing_event_ids, ) + event_id_to_state_group.update(event_to_groups) - groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys()) + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) - if groups: - group_to_state = yield self._get_state_for_groups(groups) - state_groups.update(group_to_state) + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) - if len(state_groups) == 1: + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: + return + + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: # If there is only one state group, then we know what the current # state is. - defer.returnValue(state_groups.values()[0]) + defer.returnValue(state_groups_map[new_state_groups.pop()]) + + # Ok, we need to defer to the state handler to resolve our state sets. def get_events(ev_ids): return self.get_events( ev_ids, get_prev_content=False, check_redacted=False, ) + + state_groups = { + sg: state_groups_map[sg] for sg in new_state_groups + } + events_map = {ev.event_id: ev for ev, _ in events_context} logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( @@ -1288,13 +1351,49 @@ class EventsStore(EventsWorkerStore): defer.returnValue(set(r["event_id"] for r in rows)) - def have_events(self, event_ids): + @defer.inlineCallbacks + def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. + Args: + event_ids (iterable[str]): + Returns: - dict: Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps to - None. + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = ( + "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" + % (",".join("?" * len(chunk)), ) + ) + txn.execute(sql, chunk) + for (event_id, ) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), + []): + yield self.runInteraction( + "have_seen_events", + have_seen_events_txn, + chunk, + ) + defer.returnValue(results) + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. """ if not event_ids: return defer.succeed({}) @@ -1316,9 +1415,7 @@ class EventsStore(EventsWorkerStore): return res - return self.runInteraction( - "have_events", f, - ) + return self.runInteraction("get_rejection_reasons", f) @defer.inlineCallbacks def count_daily_messages(self): diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 86c3b48ad4..ba834854e1 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -20,7 +20,7 @@ from synapse.events import FrozenEvent from synapse.events.utils import prune_event from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, make_deferred_yieldable + PreserveLoggingContext, make_deferred_yieldable, run_in_background, ) from synapse.util.metrics import Measure from synapse.api.errors import SynapseError @@ -28,7 +28,7 @@ from synapse.api.errors import SynapseError from collections import namedtuple import logging -import ujson as json +import simplejson as json # these are only included to make the type annotations work from synapse.events import EventBase # noqa: F401 @@ -51,6 +51,26 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) class EventsWorkerStore(SQLBaseStore): + def get_received_ts(self, event_id): + """Get received_ts (when it was persisted) for the event. + + Raises an exception for unknown events. + + Args: + event_id (str) + + Returns: + Deferred[int|None]: Timestamp in milliseconds, or None for events + that were persisted before received_ts was implemented. + """ + return self._simple_select_one_onecol( + table="events", + keyvalues={ + "event_id": event_id, + }, + retcol="received_ts", + desc="get_received_ts", + ) @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, @@ -299,7 +319,8 @@ class EventsWorkerStore(SQLBaseStore): res = yield make_deferred_yieldable(defer.gatherResults( [ - preserve_fn(self._get_event_from_row)( + run_in_background( + self._get_event_from_row, row["internal_metadata"], row["json"], row["redacts"], rejected_reason=row["rejects"], ) diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index 8fde1aab8e..da05ccb027 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +20,7 @@ from synapse.api.errors import SynapseError from ._base import SQLBaseStore -import ujson as json +import simplejson as json # The category ID for the "default" category. We don't store as null in the @@ -29,6 +30,24 @@ _DEFAULT_ROLE_ID = "" class GroupServerStore(SQLBaseStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. + + join_policy can be one of: + * "invite" + * "open" + """ + return self._simple_update_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + updatevalues={ + "join_policy": join_policy, + }, + desc="set_group_join_policy", + ) + def get_group(self, group_id): return self._simple_select_one( table="groups", @@ -36,10 +55,11 @@ class GroupServerStore(SQLBaseStore): "group_id": group_id, }, retcols=( - "name", "short_description", "long_description", "avatar_url", "is_public" + "name", "short_description", "long_description", + "avatar_url", "is_public", "join_policy", ), allow_none=True, - desc="is_user_in_group", + desc="get_group", ) def get_users_in_group(self, group_id, include_private=False): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index c845a0cec5..04411a665f 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 47 +SCHEMA_VERSION = 48 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index eac8694e0f..63997ed449 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -23,7 +23,7 @@ from twisted.internet import defer import abc import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d809b2ba46..a50717db2d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -22,6 +22,8 @@ from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from six.moves import range + class RegistrationWorkerStore(SQLBaseStore): @cached() @@ -460,18 +462,16 @@ class RegistrationStore(RegistrationWorkerStore, """ def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - rows = self.cursor_to_dict(txn) regex = re.compile("^@(\d+):") found = set() - for r in rows: - user_id = r["name"] + for user_id, in txn: match = regex.search(user_id) if match: found.add(int(match.group(1))) - for i in xrange(len(found) + 1): + for i in range(len(found) + 1): if i not in found: return i diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 34ed84ea22..ea6a189185 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -22,7 +22,7 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks import collections import logging -import ujson as json +import simplejson as json import re logger = logging.getLogger(__name__) @@ -530,7 +530,7 @@ class RoomStore(RoomWorkerStore, SearchStore): # Convert the IDs to MXC URIs for media_id in local_mxcs: - local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id)) + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) for hostname, media_id in remote_mxcs: remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) @@ -594,7 +594,8 @@ class RoomStore(RoomWorkerStore, SearchStore): while next_token: sql = """ - SELECT stream_ordering, content FROM events + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) WHERE room_id = ? AND stream_ordering < ? AND contains_url = ? AND outlier = ? @@ -606,8 +607,8 @@ class RoomStore(RoomWorkerStore, SearchStore): next_token = None for stream_ordering, content_json in txn: next_token = stream_ordering - content = json.loads(content_json) - + event_json = json.loads(content_json) + content = event_json["content"] content_url = content.get("url") thumbnail_url = content.get("info", {}).get("thumbnail_url") @@ -618,7 +619,7 @@ class RoomStore(RoomWorkerStore, SearchStore): if matches: hostname = matches.group(1) media_id = matches.group(2) - if hostname == self.hostname: + if hostname == self.hs.hostname: local_media_mxcs.append(media_id) else: remote_media_mxcs.append((hostname, media_id)) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 52e19e16b0..6a861943a2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -28,7 +28,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -645,8 +645,9 @@ class RoomMemberStore(RoomMemberWorkerStore): def add_membership_profile_txn(txn): sql = (""" - SELECT stream_ordering, event_id, events.room_id, content + SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events + INNER JOIN event_json USING (event_id) INNER JOIN room_memberships USING (event_id) WHERE ? <= stream_ordering AND stream_ordering < ? AND type = 'm.room.member' @@ -667,7 +668,8 @@ class RoomMemberStore(RoomMemberWorkerStore): event_id = row["event_id"] room_id = row["room_id"] try: - content = json.loads(row["content"]) + event_json = json.loads(row["json"]) + content = event_json['content'] except Exception: continue diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 8755bb2e49..4d725b92fe 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging +import simplejson as json + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index 4269ac69ad..e7351c3ae6 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -17,7 +17,7 @@ import logging from synapse.storage.prepare_database import get_statements from synapse.storage.engines import PostgresEngine, Sqlite3Engine -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index 71b12a2731..6df57b5206 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -16,7 +16,7 @@ import logging from synapse.storage.prepare_database import get_statements -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index c53e53c94f..85bd1a2006 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -14,6 +14,8 @@ import logging from synapse.config.appservice import load_appservices +from six.moves import range + logger = logging.getLogger(__name__) @@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): for as_id, user_ids in owned.items(): n = 100 - user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n)) + user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( database_engine.convert_param_style( diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py index 470ae0c005..fe6b7d196d 100644 --- a/synapse/storage/schema/delta/31/search_update.py +++ b/synapse/storage/schema/delta/31/search_update.py @@ -16,7 +16,7 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.prepare_database import get_statements import logging -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def run_create(cur, database_engine, *args, **kwargs): "rows_inserted": 0, "have_added_indexes": False, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py index 83066cccc9..1e002f9db2 100644 --- a/synapse/storage/schema/delta/33/event_fields.py +++ b/synapse/storage/schema/delta/33/event_fields.py @@ -15,7 +15,7 @@ from synapse.storage.prepare_database import get_statements import logging -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql new file mode 100644 index 0000000000..9248b0b24a --- /dev/null +++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py new file mode 100644 index 0000000000..2233af87d7 --- /dev/null +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -0,0 +1,57 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +FIX_INDEXES = """ +-- rebuild indexes as uniques +DROP INDEX groups_invites_g_idx; +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +DROP INDEX groups_users_g_idx; +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); + +-- rename other indexes to actually match their table names.. +DROP INDEX groups_users_u_idx; +CREATE INDEX group_users_u_idx ON group_users(user_id); +DROP INDEX groups_invites_u_idx; +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +DROP INDEX groups_rooms_g_idx; +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +DROP INDEX groups_rooms_r_idx; +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + + # remove duplicates from group_users & group_invites tables + cur.execute(""" + DELETE FROM group_users WHERE %s NOT IN ( + SELECT min(%s) FROM group_users GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + cur.execute(""" + DELETE FROM group_invites WHERE %s NOT IN ( + SELECT min(%s) FROM group_invites GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + + for statement in get_statements(FIX_INDEXES.splitlines()): + cur.execute(statement) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql new file mode 100644 index 0000000000..ce26eaf0c9 --- /dev/null +++ b/synapse/storage/schema/delta/48/groups_joinable.sql @@ -0,0 +1,22 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This isn't a real ENUM because sqlite doesn't support it + * and we use a default of NULL for inserted rows and interpret + * NULL at the python store level as necessary so that existing + * rows are given the correct default policy. + */ +ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 2755acff40..6ba3e59889 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -16,7 +16,7 @@ from collections import namedtuple import logging import re -import ujson as json +import simplejson as json from twisted.internet import defer @@ -75,8 +75,9 @@ class SearchStore(BackgroundUpdateStore): def reindex_search_txn(txn): sql = ( - "SELECT stream_ordering, event_id, room_id, type, content, " + "SELECT stream_ordering, event_id, room_id, type, json, " " origin_server_ts FROM events" + " JOIN event_json USING (room_id, event_id)" " WHERE ? <= stream_ordering AND stream_ordering < ?" " AND (%s)" " ORDER BY stream_ordering DESC" @@ -104,7 +105,8 @@ class SearchStore(BackgroundUpdateStore): stream_ordering = row["stream_ordering"] origin_server_ts = row["origin_server_ts"] try: - content = json.loads(row["content"]) + event_json = json.loads(row["json"]) + content = event_json["content"] except Exception: continue diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 2956c3b3e0..f0784ba137 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -41,12 +41,14 @@ from synapse.storage.events import EventsWorkerStore from synapse.util.caches.descriptors import cached from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.storage.engines import PostgresEngine, Sqlite3Engine import abc import logging +from six.moves import range + logger = logging.getLogger(__name__) @@ -196,13 +198,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = {} room_ids = list(room_ids) - for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): + for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)): res = yield make_deferred_yieldable(defer.gatherResults([ - preserve_fn(self.get_room_events_stream_for_room)( + run_in_background( + self.get_room_events_stream_for_room, room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids - ])) + ], consumeErrors=True)) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index fc46bf7bb3..6671d3cfca 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -19,9 +19,11 @@ from synapse.storage.account_data import AccountDataWorkerStore from synapse.util.caches.descriptors import cached from twisted.internet import defer -import ujson as json +import simplejson as json import logging +from six.moves import range + logger = logging.getLogger(__name__) @@ -98,7 +100,7 @@ class TagsWorkerStore(AccountDataWorkerStore): batch_size = 50 results = [] - for i in xrange(0, len(tag_ids), batch_size): + for i in range(0, len(tag_ids), batch_size): tags = yield self.runInteraction( "get_all_updated_tag_content", get_tag_content, diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 8f61f7ffae..f825264ea9 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json from collections import namedtuple import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index dfdcbb3181..d6e289ffbe 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -667,7 +667,7 @@ class UserDirectoryStore(SQLBaseStore): # The array of numbers are the weights for the various part of the # search: (domain, _, display name, localpart) sql = """ - SELECT d.user_id, display_name, avatar_url + SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) %s @@ -702,7 +702,7 @@ class UserDirectoryStore(SQLBaseStore): search_query = _parse_query_sqlite(search_term) sql = """ - SELECT d.user_id, display_name, avatar_url + SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search INNER JOIN user_directory AS d USING (user_id) %s diff --git a/synapse/types.py b/synapse/types.py index 7cb24cecb2..cc7c182a78 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -169,7 +169,7 @@ class DomainSpecificString( except Exception: return False - __str__ = to_string + __repr__ = to_string class UserID(DomainSpecificString): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 756d8ffa32..814a7bf71b 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.errors import SynapseError from synapse.util.logcontext import PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -24,11 +23,6 @@ import logging logger = logging.getLogger(__name__) -class DeferredTimedOutError(SynapseError): - def __init__(self): - super(DeferredTimedOutError, self).__init__(504, "Timed out") - - def unwrapFirstError(failure): # defer.gatherResults and DeferredLists wrap failures. failure.trap(defer.FirstError) @@ -85,53 +79,3 @@ class Clock(object): except Exception: if not ignore_errs: raise - - def time_bound_deferred(self, given_deferred, time_out): - if given_deferred.called: - return given_deferred - - ret_deferred = defer.Deferred() - - def timed_out_fn(): - e = DeferredTimedOutError() - - try: - ret_deferred.errback(e) - except Exception: - pass - - try: - given_deferred.cancel() - except Exception: - pass - - timer = None - - def cancel(res): - try: - self.cancel_call_later(timer) - except Exception: - pass - return res - - ret_deferred.addBoth(cancel) - - def success(res): - try: - ret_deferred.callback(res) - except Exception: - pass - - return res - - def err(res): - try: - ret_deferred.errback(res) - except Exception: - pass - - given_deferred.addCallbacks(callback=success, errback=err) - - timer = self.call_later(time_out, timed_out_fn) - - return ret_deferred diff --git a/synapse/util/async.py b/synapse/util/async.py index 0729bb2863..9dd4e6b5bc 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -15,9 +15,11 @@ from twisted.internet import defer, reactor +from twisted.internet.defer import CancelledError +from twisted.python import failure from .logcontext import ( - PreserveLoggingContext, make_deferred_yieldable, preserve_fn + PreserveLoggingContext, make_deferred_yieldable, run_in_background ) from synapse.util import logcontext, unwrapFirstError @@ -25,6 +27,8 @@ from contextlib import contextmanager import logging +from six.moves import range + logger = logging.getLogger(__name__) @@ -156,13 +160,13 @@ def concurrently_execute(func, args, limit): def _concurrently_execute_inner(): try: while True: - yield func(it.next()) + yield func(next(it)) except StopIteration: pass return logcontext.make_deferred_yieldable(defer.gatherResults([ - preserve_fn(_concurrently_execute_inner)() - for _ in xrange(limit) + run_in_background(_concurrently_execute_inner) + for _ in range(limit) ], consumeErrors=True)).addErrback(unwrapFirstError) @@ -392,3 +396,68 @@ class ReadWriteLock(object): self.key_to_current_writer.pop(key) defer.returnValue(_ctx_manager()) + + +class DeferredTimeoutError(Exception): + """ + This error is raised by default when a L{Deferred} times out. + """ + + +def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None): + """ + Add a timeout to a deferred by scheduling it to be cancelled after + timeout seconds. + + This is essentially a backport of deferred.addTimeout, which was introduced + in twisted 16.5. + + If the deferred gets timed out, it errbacks with a DeferredTimeoutError, + unless a cancelable function was passed to its initialization or unless + a different on_timeout_cancel callable is provided. + + Args: + deferred (defer.Deferred): deferred to be timed out + timeout (Number): seconds to time out after + + on_timeout_cancel (callable): A callable which is called immediately + after the deferred times out, and not if this deferred is + otherwise cancelled before the timeout. + + It takes an arbitrary value, which is the value of the deferred at + that exact point in time (probably a CancelledError Failure), and + the timeout. + + The default callable (if none is provided) will translate a + CancelledError Failure into a DeferredTimeoutError. + """ + timed_out = [False] + + def time_it_out(): + timed_out[0] = True + deferred.cancel() + + delayed_call = reactor.callLater(timeout, time_it_out) + + def convert_cancelled(value): + if timed_out[0]: + to_call = on_timeout_cancel or _cancelled_to_timed_out_error + return to_call(value, timeout) + return value + + deferred.addBoth(convert_cancelled) + + def cancel_timeout(result): + # stop the pending call to cancel the deferred if it's been fired + if delayed_call.active(): + delayed_call.cancel() + return result + + deferred.addBoth(cancel_timeout) + + +def _cancelled_to_timed_out_error(value, timeout): + if isinstance(value, failure.Failure): + value.trap(CancelledError) + raise DeferredTimeoutError(timeout, "Deferred") + return value diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index bf3a66eae4..68285a7594 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,12 +40,11 @@ _CacheSentinel = object() class CacheEntry(object): __slots__ = [ - "deferred", "sequence", "callbacks", "invalidated" + "deferred", "callbacks", "invalidated" ] - def __init__(self, deferred, sequence, callbacks): + def __init__(self, deferred, callbacks): self.deferred = deferred - self.sequence = sequence self.callbacks = set(callbacks) self.invalidated = False @@ -62,7 +62,6 @@ class Cache(object): "max_entries", "name", "keylen", - "sequence", "thread", "metrics", "_pending_deferred_cache", @@ -80,7 +79,6 @@ class Cache(object): self.name = name self.keylen = keylen - self.sequence = 0 self.thread = None self.metrics = register_cache(name, self.cache) @@ -113,11 +111,10 @@ class Cache(object): callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - if val.sequence == self.sequence: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) if val is not _CacheSentinel: @@ -137,12 +134,9 @@ class Cache(object): self.check_thread() entry = CacheEntry( deferred=value, - sequence=self.sequence, callbacks=callbacks, ) - entry.callbacks.update(callbacks) - existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: existing_entry.invalidate() @@ -150,13 +144,25 @@ class Cache(object): self._pending_deferred_cache[key] = entry def shuffle(result): - if self.sequence == entry.sequence: - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - self.cache.set(key, result, entry.callbacks) - else: - entry.invalidate() + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + self.cache.set(key, result, entry.callbacks) else: + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. entry.invalidate() return result @@ -168,25 +174,29 @@ class Cache(object): def invalidate(self, key): self.check_thread() + self.cache.pop(key, None) - # Increment the sequence number so that any SELECT statements that - # raced with the INSERT don't update the cache (SYN-369) - self.sequence += 1 + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. if entry: entry.invalidate() - self.cache.pop(key, None) - def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): raise TypeError( "The cache key must be a tuple not %r" % (type(key),) ) - self.sequence += 1 self.cache.del_multi(key) + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): @@ -194,8 +204,10 @@ class Cache(object): def invalidate_all(self): self.check_thread() - self.sequence += 1 self.cache.clear() + for entry in self._pending_deferred_cache.itervalues(): + entry.invalidate() + self._pending_deferred_cache.clear() class _CacheDescriptorBase(object): diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index d4105822b3..1709e8b429 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -132,9 +132,13 @@ class DictionaryCache(object): self._update_or_insert(key, value, known_absent) def _update_or_insert(self, key, value, known_absent): - entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {})) + # We pop and reinsert as we need to tell the cache the size may have + # changed + + entry = self.cache.pop(key, DictionaryEntry(False, set(), {})) entry.value.update(value) entry.known_absent.update(known_absent) + self.cache[key] = entry def _insert(self, key, value, known_absent): self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index f088dd430e..1c5a982094 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -154,14 +154,21 @@ class LruCache(object): def cache_set(key, value, callbacks=[]): node = cache.get(key, None) if node is not None: - if value != node.value: + # We sometimes store large objects, e.g. dicts, which cause + # the inequality check to take a long time. So let's only do + # the check if we have some callbacks to call. + if node.callbacks and value != node.value: for cb in node.callbacks: cb() node.callbacks.clear() - if size_callback: - cached_cache_len[0] -= size_callback(node.value) - cached_cache_len[0] += size_callback(value) + # We don't bother to protect this by value != node.value as + # generally size_callback will be cheap compared with equality + # checks. (For example, taking the size of two dicts is quicker + # than comparing them for equality.) + if size_callback: + cached_cache_len[0] -= size_callback(node.value) + cached_cache_len[0] += size_callback(value) node.callbacks.update(callbacks) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 00af539880..7f79333e96 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -12,8 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from twisted.internet import defer from synapse.util.async import ObservableDeferred +from synapse.util.caches import metrics as cache_metrics +from synapse.util.logcontext import make_deferred_yieldable, run_in_background + +logger = logging.getLogger(__name__) class ResponseCache(object): @@ -24,20 +31,68 @@ class ResponseCache(object): used rather than trying to compute a new response. """ - def __init__(self, hs, timeout_ms=0): + def __init__(self, hs, name, timeout_ms=0): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000. + self._name = name + self._metrics = cache_metrics.register_cache( + "response_cache", + size_callback=lambda: self.size(), + cache_name=name, + ) + + def size(self): + return len(self.pending_result_cache) + def get(self, key): + """Look up the given key. + + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if the request has completed, the actual + result. You will probably want to make_deferred_yieldable the result. + + If there is no entry for the key, returns None. It is worth noting that + this means there is no way to distinguish a completed result of None + from an absent cache entry. + + Args: + key (hashable): + + Returns: + twisted.internet.defer.Deferred|None|E: None if there is no entry + for this key; otherwise either a deferred result or the result + itself. + """ result = self.pending_result_cache.get(key) if result is not None: + self._metrics.inc_hits() return result.observe() else: + self._metrics.inc_misses() return None def set(self, key, deferred): + """Set the entry for the given key to the given deferred. + + *deferred* should run its callbacks in the sentinel logcontext (ie, + you should wrap normal synapse deferreds with + logcontext.run_in_background). + + Can return either a new Deferred (which also doesn't follow the synapse + logcontext rules), or, if *deferred* was already complete, the actual + result. You will probably want to make_deferred_yieldable the result. + + Args: + key (hashable): + deferred (twisted.internet.defer.Deferred[T): + + Returns: + twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual + result. + """ result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -53,3 +108,52 @@ class ResponseCache(object): result.addBoth(remove) return result.observe() + + def wrap(self, key, callback, *args, **kwargs): + """Wrap together a *get* and *set* call, taking care of logcontexts + + First looks up the key in the cache, and if it is present makes it + follow the synapse logcontext rules and returns it. + + Otherwise, makes a call to *callback(*args, **kwargs)*, which should + follow the synapse logcontext rules, and adds the result to the cache. + + Example usage: + + @defer.inlineCallbacks + def handle_request(request): + # etc + defer.returnValue(result) + + result = yield response_cache.wrap( + key, + handle_request, + request, + ) + + Args: + key (hashable): key to get/set in the cache + + callback (callable): function to call if the key is not found in + the cache + + *args: positional parameters to pass to the callback, if it is used + + **kwargs: named paramters to pass to the callback, if it is used + + Returns: + twisted.internet.defer.Deferred: yieldable result + """ + result = self.get(key) + if not result: + logger.info("[%s]: no cached result for [%s], calculating new one", + self._name, key) + d = run_in_background(callback, *args, **kwargs) + result = self.set(key, d) + elif not isinstance(result, defer.Deferred) or result.called: + logger.info("[%s]: using completed cached result for [%s]", + self._name, key) + else: + logger.info("[%s]: using incomplete cached result for [%s]", + self._name, key) + return make_deferred_yieldable(result) diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 90a2608d6f..3380970e4e 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -15,9 +15,9 @@ from twisted.internet import threads, reactor -from synapse.util.logcontext import make_deferred_yieldable, preserve_fn +from synapse.util.logcontext import make_deferred_yieldable, run_in_background -import Queue +from six.moves import queue class BackgroundFileConsumer(object): @@ -49,7 +49,7 @@ class BackgroundFileConsumer(object): # Queue of slices of bytes to be written. When producer calls # unregister a final None is sent. - self._bytes_queue = Queue.Queue() + self._bytes_queue = queue.Queue() # Deferred that is resolved when finished writing self._finished_deferred = None @@ -70,7 +70,9 @@ class BackgroundFileConsumer(object): self._producer = producer self.streaming = streaming - self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer) + self._finished_deferred = run_in_background( + threads.deferToThread, self._writer + ) if not streaming: self._producer.resumeProducing() diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 6322f0f55c..f497b51f4a 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -14,6 +14,7 @@ # limitations under the License. from frozendict import frozendict +import simplejson as json def freeze(o): @@ -49,3 +50,21 @@ def unfreeze(o): pass return o + + +def _handle_frozendict(obj): + """Helper for EventEncoder. Makes frozendicts serializable by returning + the underlying dict + """ + if type(obj) is frozendict: + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + return obj._dict + raise TypeError('Object of type %s is not JSON serializable' % + obj.__class__.__name__) + + +# A JSONEncoder which is capable of encoding frozendics without barfing +frozendict_json_encoder = json.JSONEncoder( + default=_handle_frozendict, +) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 45be47159a..e9f0f292ee 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.web.resource import Resource +from twisted.web.resource import NoResource import logging @@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource): # extra resources to existing nodes. See self._resource_id for the key. resource_mappings = {} for full_path, res in desired_tree.items(): + # twisted requires all resources to be bytes + full_path = full_path.encode("utf-8") + logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split('/')[1:-1]: + for path_seg in full_path.split(b'/')[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" - child_resource = Resource() + child_resource = NoResource() last_resource.putChild(path_seg, child_resource) res_id = _resource_id(last_resource, path_seg) resource_mappings[res_id] = child_resource @@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split('/')[-1] + last_path_seg = full_path.split(b'/')[-1] # if there is already a resource here, thieve its children and # replace it diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index fdd4a93d07..eab9d57650 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -92,6 +92,7 @@ class LoggingContext(object): def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() @@ -301,31 +302,49 @@ def preserve_fn(f): def run_in_background(f, *args, **kwargs): """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the - deferred returned by the funtion completes. + deferred returned by the function completes. Useful for wrapping functions that return a deferred which you don't yield - on. + on (for instance because you want to pass it to deferred.gatherResults()). + + Note that if you completely discard the result, you should make sure that + `f` doesn't raise any deferred exceptions, otherwise a scary-looking + CRITICAL error about an unhandled error will be logged without much + indication about where it came from. """ current = LoggingContext.current_context() - res = f(*args, **kwargs) - if isinstance(res, defer.Deferred) and not res.called: - # The function will have reset the context before returning, so - # we need to restore it now. - LoggingContext.set_current_context(current) - - # The original context will be restored when the deferred - # completes, but there is nothing waiting for it, so it will - # get leaked into the reactor or some other function which - # wasn't expecting it. We therefore need to reset the context - # here. - # - # (If this feels asymmetric, consider it this way: we are - # effectively forking a new thread of execution. We are - # probably currently within a ``with LoggingContext()`` block, - # which is supposed to have a single entry and exit point. But - # by spawning off another deferred, we are effectively - # adding a new exit point.) - res.addBoth(_set_context_cb, LoggingContext.sentinel) + try: + res = f(*args, **kwargs) + except: # noqa: E722 + # the assumption here is that the caller doesn't want to be disturbed + # by synchronous exceptions, so let's turn them into Failures. + return defer.fail() + + if not isinstance(res, defer.Deferred): + return res + + if res.called and not res.paused: + # The function should have maintained the logcontext, so we can + # optimise out the messing about + return res + + # The function may have reset the context before returning, so + # we need to restore it now. + ctx = LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(_set_context_cb, ctx) return res @@ -340,11 +359,20 @@ def make_deferred_yieldable(deferred): returning a deferred. Then, when the deferred completes, restores the current logcontext before running callbacks/errbacks. - (This is more-or-less the opposite operation to preserve_fn.) + (This is more-or-less the opposite operation to run_in_background.) """ - if isinstance(deferred, defer.Deferred) and not deferred.called: - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) - deferred.addBoth(_set_context_cb, prev_context) + if not isinstance(deferred, defer.Deferred): + return deferred + + if deferred.called and not deferred.paused: + # it looks like this deferred is ready to run any callbacks we give it + # immediately. We may as well optimise out the logcontext faffery. + return deferred + + # ok, we can't be sure that a yield won't block, so let's reset the + # logcontext, and add a callback to the deferred to restore it. + prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + deferred.addBoth(_set_context_cb, prev_context) return deferred diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py index cdbc4bffd7..3e42868ea9 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py @@ -14,7 +14,7 @@ # limitations under the License. -import StringIO +from six import StringIO import logging import traceback @@ -32,7 +32,7 @@ class LogFormatter(logging.Formatter): super(LogFormatter, self).__init__(*args, **kwargs) def formatException(self, ei): - sio = StringIO.StringIO() + sio = StringIO() (typ, val, tb) = ei # log the stack above the exception capture point if possible, but diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 1101881a2d..0ab63c3d7d 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -18,7 +18,10 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError from synapse.util.async import sleep -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import ( + run_in_background, make_deferred_yieldable, + PreserveLoggingContext, +) import collections import contextlib @@ -150,7 +153,7 @@ class _PerHostRatelimiter(object): "Ratelimit [%s]: sleeping req", id(request_id), ) - ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0) + ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0) self.sleeping_requests.add(request_id) @@ -176,6 +179,9 @@ class _PerHostRatelimiter(object): return r def on_err(r): + # XXX: why is this necessary? this is called before we start + # processing the request so why would the request be in + # current_processing? self.current_processing.discard(request_id) return r @@ -187,7 +193,7 @@ class _PerHostRatelimiter(object): ret_defer.addCallbacks(on_start, on_err) ret_defer.addBoth(on_both) - return ret_defer + return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id): logger.debug( @@ -197,7 +203,12 @@ class _PerHostRatelimiter(object): self.current_processing.discard(request_id) try: request_id, deferred = self.ready_request_queue.popitem() + + # XXX: why do we do the following? the on_start callback above will + # do it for us. self.current_processing.add(request_id) - deferred.callback(None) + + with PreserveLoggingContext(): + deferred.callback(None) except KeyError: pass diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 47b0bb5eb3..4e93f69d3a 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -203,8 +203,8 @@ class RetryDestinationLimiter(object): ) except Exception: logger.exception( - "Failed to store set_destination_retry_timings", + "Failed to store destination_retry_timings", ) # we deliberately do this in the background. - synapse.util.logcontext.preserve_fn(store_retry_timings)() + synapse.util.logcontext.run_in_background(store_retry_timings) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 95a6168e16..b98b9dc6e4 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -15,6 +15,7 @@ import random import string +from six.moves import range _string_with_symbols = ( string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" @@ -22,12 +23,12 @@ _string_with_symbols = ( def random_string(length): - return ''.join(random.choice(string.ascii_letters) for _ in xrange(length)) + return ''.join(random.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): return ''.join( - random.choice(_string_with_symbols) for _ in xrange(length) + random.choice(_string_with_symbols) for _ in range(length) ) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index b70f9a6b0a..7a9e45aca9 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from six.moves import range + class _Entry(object): __slots__ = ["end_key", "queue"] @@ -68,7 +70,7 @@ class WheelTimer(object): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. self.entries.extend( - _Entry(key) for key in xrange(last_key, then_key + 1) + _Entry(key) for key in range(last_key, then_key + 1) ) self.entries[-1].queue.append(obj) |