diff options
author | Matthew Hodgson <matthew@matrix.org> | 2018-03-13 22:11:58 +0000 |
---|---|---|
committer | Matthew Hodgson <matthew@matrix.org> | 2018-03-13 22:11:58 +0000 |
commit | 12350e3f9aeac73f80d1c5bb35ef9e5d0887dec0 (patch) | |
tree | 1db11e8ee3cd835de2c8d33aaf9534573d0ac357 /synapse | |
parent | ensure we always include the members for a given timeline block (diff) | |
parent | typoe (diff) | |
download | synapse-12350e3f9aeac73f80d1c5bb35ef9e5d0887dec0.tar.xz |
merge proper fix to bug 2969
Diffstat (limited to '')
34 files changed, 515 insertions, 301 deletions
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index 3b3352798d..0a8ce9bc66 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -156,7 +156,6 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) def start(): diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index fc0b9e8c04..eb593c5278 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -161,7 +161,6 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) def start(): diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 4de43c41f0..20d157911b 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -144,7 +144,6 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) def start(): diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index e32ee8fe93..816c080d18 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -211,7 +211,6 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) def start(): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e375f2bbcf..e477c7ced6 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -348,7 +348,7 @@ def setup(config_options): hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() hs.get_datastore().start_doing_background_updates() - hs.get_replication_layer().start_get_pdu_cache() + hs.get_federation_client().start_get_pdu_cache() register_memory_metrics(hs) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 1ed1ca8772..84c5791b3b 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -158,7 +158,6 @@ def start(config_options): ) ss.setup() - ss.get_handlers() ss.start_listening(config.worker_listeners) def start(): diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 2e32d245ba..f5f0bdfca3 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -15,11 +15,3 @@ """ This package includes all the federation specific logic. """ - -from .replication import ReplicationLayer - - -def initialize_http_replication(hs): - transport = hs.get_federation_transport_client() - - return ReplicationLayer(hs, transport) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 7918d3e442..79eaa31031 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -27,7 +27,13 @@ logger = logging.getLogger(__name__) class FederationBase(object): def __init__(self, hs): + self.hs = hs + + self.server_name = hs.hostname + self.keyring = hs.get_keyring() self.spam_checker = hs.get_spam_checker() + self.store = hs.get_datastore() + self._clock = hs.get_clock() @defer.inlineCallbacks def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 813907f7f2..38440da5b5 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -58,6 +58,7 @@ class FederationClient(FederationBase): self._clear_tried_cache, 60 * 1000, ) self.state = hs.get_state_handler() + self.transport_layer = hs.get_federation_transport_client() def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9849953c9b..bea7fd0b71 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -17,12 +17,14 @@ import logging import simplejson as json from twisted.internet import defer -from synapse.api.errors import AuthError, FederationError, SynapseError +from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError from synapse.crypto.event_signing import compute_event_signature from synapse.federation.federation_base import ( FederationBase, event_from_pdu_json, ) + +from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction import synapse.metrics from synapse.types import get_domain_from_id @@ -52,50 +54,19 @@ class FederationServer(FederationBase): super(FederationServer, self).__init__(hs) self.auth = hs.get_auth() + self.handler = hs.get_handlers().federation_handler self._server_linearizer = async.Linearizer("fed_server") self._transaction_linearizer = async.Linearizer("fed_txn_handler") + self.transaction_actions = TransactionActions(self.store) + + self.registry = hs.get_federation_registry() + # We cache responses to state queries, as they take a while and often # come in waves. self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) - def set_handler(self, handler): - """Sets the handler that the replication layer will use to communicate - receipt of new PDUs from other home servers. The required methods are - documented on :py:class:`.ReplicationHandler`. - """ - self.handler = handler - - def register_edu_handler(self, edu_type, handler): - if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type,)) - - self.edu_handlers[edu_type] = handler - - def register_query_handler(self, query_type, handler): - """Sets the handler callable that will be used to handle an incoming - federation Query of the given type. - - Args: - query_type (str): Category name of the query, which should match - the string used by make_query. - handler (callable): Invoked to handle incoming queries of this type - - handler is invoked as: - result = handler(args) - - where 'args' is a dict mapping strings to strings of the query - arguments. It should return a Deferred that will eventually yield an - object to encode as JSON. - """ - if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) - - self.query_handlers[query_type] = handler - @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, versions, limit): @@ -229,16 +200,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - - if edu_type in self.edu_handlers: - try: - yield self.edu_handlers[edu_type](origin, content) - except SynapseError as e: - logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception as e: - logger.exception("Failed to handle edu %r", edu_type) - else: - logger.warn("Received EDU of type %s with no handler", edu_type) + yield self.registry.on_edu(edu_type, origin, content) @defer.inlineCallbacks @log_function @@ -328,14 +290,8 @@ class FederationServer(FederationBase): @defer.inlineCallbacks def on_query_request(self, query_type, args): received_queries_counter.inc(query_type) - - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue( - (404, "No handler for Query type '%s'" % (query_type,)) - ) + resp = yield self.registry.on_query(query_type, args) + defer.returnValue((200, resp)) @defer.inlineCallbacks def on_make_join_request(self, room_id, user_id): @@ -607,3 +563,66 @@ class FederationServer(FederationBase): origin, room_id, event_dict ) defer.returnValue(ret) + + +class FederationHandlerRegistry(object): + """Allows classes to register themselves as handlers for a given EDU or + query type for incoming federation traffic. + """ + def __init__(self): + self.edu_handlers = {} + self.query_handlers = {} + + def register_edu_handler(self, edu_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation EDU of the given type. + + Args: + edu_type (str): The type of the incoming EDU to register handler for + handler (Callable[[str, dict]]): A callable invoked on incoming EDU + of the given type. The arguments are the origin server name and + the EDU contents. + """ + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + + self.edu_handlers[edu_type] = handler + + def register_query_handler(self, query_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation query of the given type. + + Args: + query_type (str): Category name of the query, which should match + the string used by make_query. + handler (Callable[[dict], Deferred[dict]]): Invoked to handle + incoming queries of this type. The return will be yielded + on and the result used as the response to the query request. + """ + if query_type in self.query_handlers: + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) + + self.query_handlers[query_type] = handler + + @defer.inlineCallbacks + def on_edu(self, edu_type, origin, content): + handler = self.edu_handlers.get(edu_type) + if not handler: + logger.warn("No handler registered for EDU type %s", edu_type) + + try: + yield handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception as e: + logger.exception("Failed to handle edu %r", edu_type) + + def on_query(self, query_type, args): + handler = self.query_handlers.get(query_type) + if not handler: + logger.warn("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) + + return handler(args) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py deleted file mode 100644 index 62d865ec4b..0000000000 --- a/synapse/federation/replication.py +++ /dev/null @@ -1,73 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This layer is responsible for replicating with remote home servers using -a given transport. -""" - -from .federation_client import FederationClient -from .federation_server import FederationServer - -from .persistence import TransactionActions - -import logging - - -logger = logging.getLogger(__name__) - - -class ReplicationLayer(FederationClient, FederationServer): - """This layer is responsible for replicating with remote home servers over - the given transport. I.e., does the sending and receiving of PDUs to - remote home servers. - - The layer communicates with the rest of the server via a registered - ReplicationHandler. - - In more detail, the layer: - * Receives incoming data and processes it into transactions and pdus. - * Fetches any PDUs it thinks it might have missed. - * Keeps the current state for contexts up to date by applying the - suitable conflict resolution. - * Sends outgoing pdus wrapped in transactions. - * Fills out the references to previous pdus/transactions appropriately - for outgoing data. - """ - - def __init__(self, hs, transport_layer): - self.server_name = hs.hostname - - self.keyring = hs.get_keyring() - - self.transport_layer = transport_layer - - self.federation_client = self - - self.store = hs.get_datastore() - - self.handler = None - self.edu_handlers = {} - self.query_handlers = {} - - self._clock = hs.get_clock() - - self.transaction_actions = TransactionActions(self.store) - - self.hs = hs - - super(ReplicationLayer, self).__init__(hs) - - def __str__(self): - return "<ReplicationLayer(%s)>" % self.server_name diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 06c16ba4fa..a66a6b0692 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1190,7 +1190,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = ( def register_servlets(hs, resource, authenticator, ratelimiter): for servletclass in FEDERATION_SERVLET_CLASSES: servletclass( - handler=hs.get_replication_layer(), + handler=hs.get_federation_server(), authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 0e83453851..40f3d24678 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -37,14 +37,15 @@ class DeviceHandler(BaseHandler): self.state = hs.get_state_handler() self._auth_handler = hs.get_auth_handler() self.federation_sender = hs.get_federation_sender() - self.federation = hs.get_replication_layer() self._edu_updater = DeviceListEduUpdater(hs, self) - self.federation.register_edu_handler( + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( "m.device_list_update", self._edu_updater.incoming_device_list_update, ) - self.federation.register_query_handler( + federation_registry.register_query_handler( "user_devices", self.on_federation_query_user_devices, ) @@ -430,7 +431,7 @@ class DeviceListEduUpdater(object): def __init__(self, hs, device_handler): self.store = hs.get_datastore() - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_client() self.clock = hs.get_clock() self.device_handler = device_handler diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index d996aa90bb..f147a20b73 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -37,7 +37,7 @@ class DeviceMessageHandler(object): self.is_mine = hs.is_mine self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler( + hs.get_federation_registry().register_edu_handler( "m.direct_to_device", self.on_direct_to_device_edu ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8580ada60a..c5b6e75e03 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -36,8 +36,8 @@ class DirectoryHandler(BaseHandler): self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() - self.federation = hs.get_replication_layer() - self.federation.register_query_handler( + self.federation = hs.get_federation_client() + hs.get_federation_registry().register_query_handler( "directory", self.on_directory_query ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9aa95f89e6..31b1ece13e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) class E2eKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() self.is_mine = hs.is_mine self.clock = hs.get_clock() @@ -40,7 +40,7 @@ class E2eKeysHandler(object): # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - self.federation.register_query_handler( + hs.get_federation_registry().register_query_handler( "client_keys", self.on_federation_query_client_keys ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 520612683e..080aca3d71 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -68,7 +68,7 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() - self.replication_layer = hs.get_replication_layer() + self.replication_layer = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() @@ -78,8 +78,6 @@ class FederationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() - self.replication_layer.set_handler(self) - # When joining a room we need to queue any events for that room up self.room_queues = {} self._room_pdu_linearizer = Linearizer("fed_room_pdu") diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index dd00d8a86c..4f97c8db79 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -13,7 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from twisted.internet import defer, reactor +from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError @@ -24,9 +25,10 @@ from synapse.types import ( UserID, RoomAlias, RoomStreamToken, ) from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter -from synapse.util.logcontext import preserve_fn +from synapse.util.logcontext import preserve_fn, run_in_background from synapse.util.metrics import measure_func from synapse.util.frozenutils import unfreeze +from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client from synapse.replication.http.send_event import send_event_to_master @@ -41,6 +43,36 @@ import ujson logger = logging.getLogger(__name__) +class PurgeStatus(object): + """Object tracking the status of a purge request + + This class contains information on the progress of a purge request, for + return by get_purge_status. + + Attributes: + status (int): Tracks whether this request has completed. One of + STATUS_{ACTIVE,COMPLETE,FAILED} + """ + + STATUS_ACTIVE = 0 + STATUS_COMPLETE = 1 + STATUS_FAILED = 2 + + STATUS_TEXT = { + STATUS_ACTIVE: "active", + STATUS_COMPLETE: "complete", + STATUS_FAILED: "failed", + } + + def __init__(self): + self.status = PurgeStatus.STATUS_ACTIVE + + def asdict(self): + return { + "status": PurgeStatus.STATUS_TEXT[self.status] + } + + class MessageHandler(BaseHandler): def __init__(self, hs): @@ -50,15 +82,88 @@ class MessageHandler(BaseHandler): self.clock = hs.get_clock() self.pagination_lock = ReadWriteLock() + self._purges_in_progress_by_room = set() + # map from purge id to PurgeStatus + self._purges_by_id = {} - @defer.inlineCallbacks - def purge_history(self, room_id, topological_ordering, - delete_local_events=False): - with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history( - room_id, topological_ordering, delete_local_events, + def start_purge_history(self, room_id, topological_ordering, + delete_local_events=False): + """Start off a history purge on a room. + + Args: + room_id (str): The room to purge from + + topological_ordering (int): minimum topo ordering to preserve + delete_local_events (bool): True to delete local events as well as + remote ones + + Returns: + str: unique ID for this purge transaction. + """ + if room_id in self._purges_in_progress_by_room: + raise SynapseError( + 400, + "History purge already in progress for %s" % (room_id, ), ) + purge_id = random_string(16) + + # we log the purge_id here so that it can be tied back to the + # request id in the log lines. + logger.info("[purge] starting purge_id %s", purge_id) + + self._purges_by_id[purge_id] = PurgeStatus() + run_in_background( + self._purge_history, + purge_id, room_id, topological_ordering, delete_local_events, + ) + return purge_id + + @defer.inlineCallbacks + def _purge_history(self, purge_id, room_id, topological_ordering, + delete_local_events): + """Carry out a history purge on a room. + + Args: + purge_id (str): The id for this purge + room_id (str): The room to purge from + topological_ordering (int): minimum topo ordering to preserve + delete_local_events (bool): True to delete local events as well as + remote ones + + Returns: + Deferred + """ + self._purges_in_progress_by_room.add(room_id) + try: + with (yield self.pagination_lock.write(room_id)): + yield self.store.purge_history( + room_id, topological_ordering, delete_local_events, + ) + logger.info("[purge] complete") + self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE + except Exception: + logger.error("[purge] failed: %s", Failure().getTraceback().rstrip()) + self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED + finally: + self._purges_in_progress_by_room.discard(room_id) + + # remove the purge from the list 24 hours after it completes + def clear_purge(): + del self._purges_by_id[purge_id] + reactor.callLater(24 * 3600, clear_purge) + + def get_purge_status(self, purge_id): + """Get the current status of an active purge + + Args: + purge_id (str): purge_id returned by start_purge_history + + Returns: + PurgeStatus|None + """ + return self._purges_by_id.get(purge_id) + @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, as_client_event=True, event_filter=None): @@ -562,7 +667,7 @@ class EventCreationHandler(object): event (FrozenEvent) context (EventContext) ratelimit (bool) - extra_users (list(str)): Any extra users to notify about event + extra_users (list(UserID)): Any extra users to notify about event """ try: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index cb158ba962..a5e501897c 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -93,29 +93,30 @@ class PresenceHandler(object): self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.replication = hs.get_replication_layer() self.federation = hs.get_federation_sender() self.state = hs.get_state_handler() - self.replication.register_edu_handler( + federation_registry = hs.get_federation_registry() + + federation_registry.register_edu_handler( "m.presence", self.incoming_presence ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.replication.register_edu_handler( + federation_registry.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( observed_user=UserID.from_string(content["observed_user"]), diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c9c2879038..cb710fe796 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -31,8 +31,8 @@ class ProfileHandler(BaseHandler): def __init__(self, hs): super(ProfileHandler, self).__init__(hs) - self.federation = hs.get_replication_layer() - self.federation.register_query_handler( + self.federation = hs.get_federation_client() + hs.get_federation_registry().register_query_handler( "profile", self.on_profile_query ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 0525765272..3f215c2b4e 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -35,7 +35,7 @@ class ReceiptsHandler(BaseHandler): self.store = hs.get_datastore() self.hs = hs self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler( + hs.get_federation_registry().register_edu_handler( "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 9021d4d57f..ed5939880a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -446,16 +446,34 @@ class RegistrationHandler(BaseHandler): return self.hs.get_auth_handler() @defer.inlineCallbacks - def guest_access_token_for(self, medium, address, inviter_user_id): + def get_or_register_3pid_guest(self, medium, address, inviter_user_id): + """Get a guest access token for a 3PID, creating a guest account if + one doesn't already exist. + + Args: + medium (str) + address (str) + inviter_user_id (str): The user ID who is trying to invite the + 3PID + + Returns: + Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the + 3PID guest account. + """ access_token = yield self.store.get_3pid_guest_access_token(medium, address) if access_token: - defer.returnValue(access_token) + user_info = yield self.auth.get_user_by_access_token( + access_token + ) - _, access_token = yield self.register( + defer.returnValue((user_info["user"].to_string(), access_token)) + + user_id, access_token = yield self.register( generate_token=True, make_guest=True ) access_token = yield self.store.save_or_get_3pid_guest_access_token( medium, address, access_token, inviter_user_id ) - defer.returnValue(access_token) + + defer.returnValue((user_id, access_token)) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index dfa09141ed..5d81f59b44 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -409,7 +409,7 @@ class RoomListHandler(BaseHandler): def _get_remote_list_cached(self, server_name, limit=None, since_token=None, search_filter=None, include_all_networks=False, third_party_instance_id=None,): - repl_layer = self.hs.get_replication_layer() + repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ed3b97730d..0127cf4166 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -55,7 +55,6 @@ class RoomMemberHandler(object): self.registration_handler = hs.get_handlers().registration_handler self.profile_handler = hs.get_profile_handler() self.event_creation_hander = hs.get_event_creation_handler() - self.replication_layer = hs.get_replication_layer() self.member_linearizer = Linearizer(name="member") @@ -138,7 +137,20 @@ class RoomMemberHandler(object): defer.returnValue(event) @defer.inlineCallbacks - def remote_join(self, remote_room_hosts, room_id, user, content): + def _remote_join(self, remote_room_hosts, room_id, user, content): + """Try and join a room that this server is not in + + Args: + remote_room_hosts (list[str]): List of servers that can be used + to join via. + room_id (str): Room that we are trying to join + user (UserID): User who is trying to join + content (dict): A dict that should be used as the content of the + join event. + + Returns: + Deferred + """ if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") @@ -155,6 +167,43 @@ class RoomMemberHandler(object): yield user_joined_room(self.distributor, user, room_id) @defer.inlineCallbacks + def _remote_reject_invite(self, remote_room_hosts, room_id, target): + """Attempt to reject an invite for a room this server is not in. If we + fail to do so we locally mark the invite as rejected. + + Args: + remote_room_hosts (list[str]): List of servers to use to try and + reject invite + room_id (str) + target (UserID): The user rejecting the invite + + Returns: + Deferred[dict]: A dictionary to be returned to the client, may + include event_id etc, or nothing if we locally rejected + """ + fed_handler = self.federation_handler + try: + ret = yield fed_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + target.to_string(), + ) + defer.returnValue(ret) + except Exception as e: + # if we were unable to reject the exception, just mark + # it as rejected on our end and plough ahead. + # + # The 'except' clause is very broad, but we need to + # capture everything from DNS failures upwards + # + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + defer.returnValue({}) + + @defer.inlineCallbacks def update_membership( self, requester, @@ -212,7 +261,7 @@ class RoomMemberHandler(object): # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: - yield self.replication_layer.exchange_third_party_invite( + yield self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -292,7 +341,7 @@ class RoomMemberHandler(object): raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self.get_inviter(target.to_string(), room_id) + inviter = yield self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -306,7 +355,7 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - ret = yield self.remote_join( + ret = yield self._remote_join( remote_room_hosts, room_id, target, content ) defer.returnValue(ret) @@ -314,7 +363,7 @@ class RoomMemberHandler(object): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self.get_inviter(target.to_string(), room_id) + inviter = yield self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -328,28 +377,10 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - fed_handler = self.federation_handler - try: - ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - target.to_string(), - ) - defer.returnValue(ret) - except Exception as e: - # if we were unable to reject the exception, just mark - # it as rejected on our end and plough ahead. - # - # The 'except' clause is very broad, but we need to - # capture everything from DNS failures upwards - # - logger.warn("Failed to reject invite: %s", e) - - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) - - defer.returnValue({}) + res = yield self._remote_reject_invite( + remote_room_hosts, room_id, target, + ) + defer.returnValue(res) res = yield self._local_membership_update( requester=requester, @@ -496,7 +527,7 @@ class RoomMemberHandler(object): defer.returnValue((RoomID.from_string(room_id), servers)) @defer.inlineCallbacks - def get_inviter(self, user_id, room_id): + def _get_inviter(self, user_id, room_id): invite = yield self.store.get_invite_for_user_in_room( user_id=user_id, room_id=room_id, @@ -573,7 +604,7 @@ class RoomMemberHandler(object): if "mxid" in data: if "signatures" not in data: raise AuthError(401, "No signatures on 3pid binding") - yield self.verify_any_signature(data, id_server) + yield self._verify_any_signature(data, id_server) defer.returnValue(data["mxid"]) except IOError as e: @@ -581,7 +612,7 @@ class RoomMemberHandler(object): defer.returnValue(None) @defer.inlineCallbacks - def verify_any_signature(self, data, server_hostname): + def _verify_any_signature(self, data, server_hostname): if server_hostname not in data["signatures"]: raise AuthError(401, "No signature from server %s" % (server_hostname,)) for key_name, signature in data["signatures"][server_hostname].items(): @@ -735,20 +766,16 @@ class RoomMemberHandler(object): } if self.config.invite_3pid_guest: - registration_handler = self.registration_handler - guest_access_token = yield registration_handler.guest_access_token_for( + rh = self.registration_handler + guest_user_id, guest_access_token = yield rh.get_or_register_3pid_guest( medium=medium, address=address, inviter_user_id=inviter_user_id, ) - guest_user_info = yield self.auth.get_user_by_access_token( - guest_access_token - ) - invite_config.update({ "guest_access_token": guest_access_token, - "guest_user_id": guest_user_info["user"].to_string(), + "guest_user_id": guest_user_id, }) data = yield self.simple_http_client.post_urlencoded_get_json( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 82dedbbc99..77c0cf146f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -56,7 +56,7 @@ class TypingHandler(object): self.federation = hs.get_federation_sender() - hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) + hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) hs.get_distributor().observe("user_left_room", self.user_left_room) diff --git a/synapse/http/server.py b/synapse/http/server.py index 165c684d0d..4b567215c8 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -59,6 +60,11 @@ response_count = metrics.register_counter( ) ) +requests_counter = metrics.register_counter( + "requests_received", + labels=["method", "servlet", ], +) + outgoing_responses_counter = metrics.register_counter( "responses", labels=["method", "code"], @@ -145,7 +151,8 @@ def wrap_request_handler(request_handler, include_metrics=False): # at the servlet name. For most requests that name will be # JsonResource (or a subclass), and JsonResource._async_render # will update it once it picks a servlet. - request_metrics.start(self.clock, name=self.__class__.__name__) + servlet_name = self.__class__.__name__ + request_metrics.start(self.clock, name=servlet_name) request_context.request = request_id with request.processing(): @@ -154,6 +161,7 @@ def wrap_request_handler(request_handler, include_metrics=False): if include_metrics: yield request_handler(self, request, request_metrics) else: + requests_counter.inc(request.method, servlet_name) yield request_handler(self, request) except CodeMessageException as e: code = e.code @@ -229,7 +237,7 @@ class JsonResource(HttpServer, resource.Resource): """ This implements the HttpServer interface and provides JSON support for Resources. - Register callbacks via register_path() + Register callbacks via register_paths() Callbacks can return a tuple of status code and a dict in which case the the dict will automatically be sent to the client as a JSON object. @@ -276,49 +284,59 @@ class JsonResource(HttpServer, resource.Resource): This checks if anyone has registered a callback for that method and path. """ - if request.method == "OPTIONS": - self._send_response(request, 200, {}) - return + callback, group_dict = self._get_handler_for_request(request) - # Loop through all the registered callbacks to check if the method - # and path regex match - for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path) - if not m: - continue + servlet_instance = getattr(callback, "__self__", None) + if servlet_instance is not None: + servlet_classname = servlet_instance.__class__.__name__ + else: + servlet_classname = "%r" % callback - # We found a match! First update the metrics object to indicate - # which servlet is handling the request. + request_metrics.name = servlet_classname + requests_counter.inc(request.method, servlet_classname) - callback = path_entry.callback + # Now trigger the callback. If it returns a response, we send it + # here. If it throws an exception, that is handled by the wrapper + # installed by @request_handler. - servlet_instance = getattr(callback, "__self__", None) - if servlet_instance is not None: - servlet_classname = servlet_instance.__class__.__name__ - else: - servlet_classname = "%r" % callback + kwargs = intern_dict({ + name: urllib.unquote(value).decode("UTF-8") if value else value + for name, value in group_dict.items() + }) - request_metrics.name = servlet_classname + callback_return = yield callback(request, **kwargs) + if callback_return is not None: + code, response = callback_return + self._send_response(request, code, response) - # Now trigger the callback. If it returns a response, we send it - # here. If it throws an exception, that is handled by the wrapper - # installed by @request_handler. + def _get_handler_for_request(self, request): + """Finds a callback method to handle the given request - kwargs = intern_dict({ - name: urllib.unquote(value).decode("UTF-8") if value else value - for name, value in m.groupdict().items() - }) + Args: + request (twisted.web.http.Request): + + Returns: + Tuple[Callable, dict[str, str]]: callback method, and the dict + mapping keys to path components as specified in the handler's + path match regexp. - callback_return = yield callback(request, **kwargs) - if callback_return is not None: - code, response = callback_return - self._send_response(request, code, response) + The callback will normally be a method registered via + register_paths, so will return (possibly via Deferred) either + None, or a tuple of (http code, response body). + """ + if request.method == "OPTIONS": + return _options_handler, {} - return + # Loop through all the registered callbacks to check if the method + # and path regex match + for path_entry in self.path_regexs.get(request.method, []): + m = path_entry.pattern.match(request.path) + if m: + # We found a match! + return path_entry.callback, m.groupdict() # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest" - raise UnrecognizedRequestError() + return _unrecognised_request_handler, {} def _send_response(self, request, code, response_json_object, response_code_message=None): @@ -335,6 +353,34 @@ class JsonResource(HttpServer, resource.Resource): ) +def _options_handler(request): + """Request handler for OPTIONS requests + + This is a request handler suitable for return from + _get_handler_for_request. It returns a 200 and an empty body. + + Args: + request (twisted.web.http.Request): + + Returns: + Tuple[int, dict]: http code, response body. + """ + return 200, {} + + +def _unrecognised_request_handler(request): + """Request handler for unrecognised requests + + This is a request handler suitable for return from + _get_handler_for_request. It actually just raises an + UnrecognizedRequestError. + + Args: + request (twisted.web.http.Request): + """ + raise UnrecognizedRequestError() + + class RequestMetrics(object): def start(self, clock, name): self.start = clock.time_msec() diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index e0cfb7d08f..50d99d7a5c 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -57,15 +57,31 @@ class Metrics(object): return metric def register_counter(self, *args, **kwargs): + """ + Returns: + CounterMetric + """ return self._register(CounterMetric, *args, **kwargs) def register_callback(self, *args, **kwargs): + """ + Returns: + CallbackMetric + """ return self._register(CallbackMetric, *args, **kwargs) def register_distribution(self, *args, **kwargs): + """ + Returns: + DistributionMetric + """ return self._register(DistributionMetric, *args, **kwargs) def register_cache(self, *args, **kwargs): + """ + Returns: + CacheMetric + """ return self._register(CacheMetric, *args, **kwargs) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 70f2fe456a..bbe2f967b7 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -25,7 +25,7 @@ from synapse.util.async import sleep from synapse.util.caches.response_cache import ResponseCache from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.metrics import Measure -from synapse.types import Requester +from synapse.types import Requester, UserID import logging import re @@ -46,7 +46,7 @@ def send_event_to_master(client, host, port, requester, event, context, event (FrozenEvent) context (EventContext) ratelimit (bool) - extra_users (list(str)): Any extra users to notify about event + extra_users (list(UserID)): Any extra users to notify about event """ uri = "http://%s:%s/_synapse/replication/send_event/%s" % ( host, port, event.event_id, @@ -59,7 +59,7 @@ def send_event_to_master(client, host, port, requester, event, context, "context": context.serialize(event), "requester": requester.serialize(), "ratelimit": ratelimit, - "extra_users": extra_users, + "extra_users": [u.to_string() for u in extra_users], } try: @@ -143,7 +143,7 @@ class ReplicationSendEventRestServlet(RestServlet): context = yield EventContext.deserialize(self.store, content["context"]) ratelimit = content["ratelimit"] - extra_users = content["extra_users"] + extra_users = [UserID.from_string(u) for u in content["extra_users"]] if requester.user: request.authenticated_entity = requester.user.to_string() diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index dcf6215dad..303419d281 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.constants import Membership -from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError from synapse.types import UserID, create_requester from synapse.http.servlet import parse_json_object_from_request @@ -185,12 +185,43 @@ class PurgeHistoryRestServlet(ClientV1RestServlet): errcode=Codes.BAD_JSON, ) - yield self.handlers.message_handler.purge_history( + purge_id = yield self.handlers.message_handler.start_purge_history( room_id, depth, delete_local_events=delete_local_events, ) - defer.returnValue((200, {})) + defer.returnValue((200, { + "purge_id": purge_id, + })) + + +class PurgeHistoryStatusRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history_status/(?P<purge_id>[^/]+)" + ) + + def __init__(self, hs): + """ + + Args: + hs (synapse.server.HomeServer) + """ + super(PurgeHistoryStatusRestServlet, self).__init__(hs) + self.handlers = hs.get_handlers() + + @defer.inlineCallbacks + def on_GET(self, request, purge_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + purge_status = self.handlers.message_handler.get_purge_status(purge_id) + if purge_status is None: + raise NotFoundError("purge id '%s' not found" % purge_id) + + defer.returnValue((200, purge_status.asdict())) class DeactivateAccountRestServlet(ClientV1RestServlet): @@ -561,6 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet): def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) + PurgeHistoryStatusRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9d745174c7..f8999d64d7 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -599,7 +599,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" - "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)") + "(?P<membership_action>join|invite|leave|ban|unban|kick)") register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks diff --git a/synapse/server.py b/synapse/server.py index 5b6effbe31..43c6e0a6d6 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -32,8 +32,10 @@ from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.events.spamcheck import SpamChecker -from synapse.federation import initialize_http_replication +from synapse.federation.federation_client import FederationClient +from synapse.federation.federation_server import FederationServer from synapse.federation.send_queue import FederationRemoteSendQueue +from synapse.federation.federation_server import FederationHandlerRegistry from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transaction_queue import TransactionQueue from synapse.handlers import Handlers @@ -99,7 +101,8 @@ class HomeServer(object): DEPENDENCIES = [ 'http_client', 'db_pool', - 'replication_layer', + 'federation_client', + 'federation_server', 'handlers', 'v1auth', 'auth', @@ -147,6 +150,7 @@ class HomeServer(object): 'groups_attestation_renewer', 'spam_checker', 'room_member_handler', + 'federation_registry', ] def __init__(self, hostname, **kwargs): @@ -195,8 +199,11 @@ class HomeServer(object): def get_ratelimiter(self): return self.ratelimiter - def build_replication_layer(self): - return initialize_http_replication(self) + def build_federation_client(self): + return FederationClient(self) + + def build_federation_server(self): + return FederationServer(self) def build_handlers(self): return Handlers(self) @@ -387,6 +394,9 @@ class HomeServer(object): def build_room_member_handler(self): return RoomMemberHandler(self) + def build_federation_registry(self): + return FederationHandlerRegistry() + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 826fad307e..3890878170 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -283,10 +283,11 @@ class EventsStore(EventsWorkerStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 405e6b6770..4ab16e18b2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -245,8 +245,11 @@ class StateGroupWorkerStore(SQLBaseStore): if types: clause_to_args = [ ( - "AND type = ? AND state_key = ?" if state_key is not None else "AND type = ?", - (etype, state_key) if state_key is not None else (etype) + "AND type = ? AND state_key = ?", + (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) ) for etype, state_key in types ] @@ -277,22 +280,25 @@ class StateGroupWorkerStore(SQLBaseStore): results[group][key] = event_id else: where_args = [] + where_clauses = [] + wildcard_types = False if types is not None: - where_clause = "AND (" for typ in types: if typ[1] is None: - where_clause += "(type = ?) OR " + where_clauses.append("(type = ?)") where_args.extend(typ[0]) + wildcard_types = True else: - where_clause += "(type = ? AND state_key = ?) OR " + where_clauses.append("(type = ? AND state_key = ?)") where_args.extend([typ[0], typ[1]]) if include_other_types: - where_clause += "(%s) OR " % ( - " AND ".join(["type <> ?"] * len(types)), + where_clauses.append( + "(" + " AND ".join(["type <> ?"] * len(types)) + ")" ) where_args.extend(t for (t, _) in types) - where_clause += "0)" # 0 to terminate the last OR + + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -322,9 +328,17 @@ class StateGroupWorkerStore(SQLBaseStore): if (typ, state_key) not in results[group] ) - # If the lengths match then we must have all the types, - # so no need to go walk further down the tree. - if types is not None and len(results[group]) == len(types): + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + types is not None and + not wildcard_types and + len(results[group]) == len(types) + ): break next_group = self._simple_select_one_onecol_txn( diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index a8dea15c1b..d660ec785b 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -292,36 +292,41 @@ class PreserveLoggingContext(object): def preserve_fn(f): - """Wraps a function, to ensure that the current context is restored after + """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): + return run_in_background(f, *args, **kwargs) + return g + + +def run_in_background(f, *args, **kwargs): + """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the funtion completes. Useful for wrapping functions that return a deferred which you don't yield on. """ - def g(*args, **kwargs): - current = LoggingContext.current_context() - res = f(*args, **kwargs) - if isinstance(res, defer.Deferred) and not res.called: - # The function will have reset the context before returning, so - # we need to restore it now. - LoggingContext.set_current_context(current) - - # The original context will be restored when the deferred - # completes, but there is nothing waiting for it, so it will - # get leaked into the reactor or some other function which - # wasn't expecting it. We therefore need to reset the context - # here. - # - # (If this feels asymmetric, consider it this way: we are - # effectively forking a new thread of execution. We are - # probably currently within a ``with LoggingContext()`` block, - # which is supposed to have a single entry and exit point. But - # by spawning off another deferred, we are effectively - # adding a new exit point.) - res.addBoth(_set_context_cb, LoggingContext.sentinel) - return res - return g + current = LoggingContext.current_context() + res = f(*args, **kwargs) + if isinstance(res, defer.Deferred) and not res.called: + # The function will have reset the context before returning, so + # we need to restore it now. + LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(_set_context_cb, LoggingContext.sentinel) + return res def make_deferred_yieldable(deferred): |