diff options
author | Erik Johnston <erik@matrix.org> | 2017-04-04 09:46:16 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2017-04-04 09:46:16 +0100 |
commit | 62b89daac6f4418d831d531fa40ad71aa0c74fb9 (patch) | |
tree | a5fa797ac7e65f85cae52cccd88451e5672257c6 /synapse | |
parent | Always advance stream tokens (diff) | |
parent | Merge pull request #2095 from matrix-org/rav/cull_log_preserves (diff) | |
download | synapse-62b89daac6f4418d831d531fa40ad71aa0c74fb9.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/repl_tcp_server
Diffstat (limited to '')
-rw-r--r-- | synapse/__init__.py | 2 | ||||
-rwxr-xr-x | synapse/app/synctl.py | 3 | ||||
-rw-r--r-- | synapse/appservice/__init__.py | 38 | ||||
-rw-r--r-- | synapse/config/voip.py | 8 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 8 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 59 | ||||
-rw-r--r-- | synapse/handlers/message.py | 2 | ||||
-rw-r--r-- | synapse/notifier.py | 23 | ||||
-rw-r--r-- | synapse/push/push_rule_evaluator.py | 104 | ||||
-rw-r--r-- | synapse/push/push_tools.py | 7 | ||||
-rw-r--r-- | synapse/rest/client/v1/voip.py | 5 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/thirdparty.py | 8 | ||||
-rw-r--r-- | synapse/types.py | 4 | ||||
-rw-r--r-- | synapse/util/async.py | 7 | ||||
-rw-r--r-- | synapse/util/caches/descriptors.py | 70 | ||||
-rw-r--r-- | synapse/util/logcontext.py | 29 | ||||
-rw-r--r-- | synapse/visibility.py | 3 |
17 files changed, 228 insertions, 152 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 7628e7c505..580927abf4 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.19.3" +__version__ = "0.20.0-rc1" diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 23eb6a1ec4..e8218d01ad 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -202,7 +202,8 @@ def main(): worker_app = worker_config["worker_app"] worker_pidfile = worker_config["worker_pid_file"] worker_daemonize = worker_config["worker_daemonize"] - assert worker_daemonize # TODO print something more user friendly + assert worker_daemonize, "In config %r: expected '%s' to be True" % ( + worker_configfile, "worker_daemonize") worker_cache_factor = worker_config.get("synctl_cache_factor") workers.append(Worker( worker_app, worker_configfile, worker_pidfile, worker_cache_factor, diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index b0106a3597..7346206bb1 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.api.constants import EventTypes +from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer @@ -124,29 +125,23 @@ class ApplicationService(object): raise ValueError( "Expected bool for 'exclusive' in ns '%s'" % ns ) - if not isinstance(regex_obj.get("regex"), basestring): + regex = regex_obj.get("regex") + if isinstance(regex, basestring): + regex_obj["regex"] = re.compile(regex) # Pre-compile regex + else: raise ValueError( "Expected string for 'regex' in ns '%s'" % ns ) return namespaces - def _matches_regex(self, test_string, namespace_key, return_obj=False): - if not isinstance(test_string, basestring): - logger.error( - "Expected a string to test regex against, but got %s", - test_string - ) - return False - + def _matches_regex(self, test_string, namespace_key): for regex_obj in self.namespaces[namespace_key]: - if re.match(regex_obj["regex"], test_string): - if return_obj: - return regex_obj - return True - return False + if regex_obj["regex"].match(test_string): + return regex_obj + return None def _is_exclusive(self, ns_key, test_string): - regex_obj = self._matches_regex(test_string, ns_key, return_obj=True) + regex_obj = self._matches_regex(test_string, ns_key) if regex_obj: return regex_obj["exclusive"] return False @@ -166,7 +161,14 @@ class ApplicationService(object): if not store: defer.returnValue(False) - member_list = yield store.get_users_in_room(event.room_id) + does_match = yield self._matches_user_in_member_list(event.room_id, store) + defer.returnValue(does_match) + + @cachedInlineCallbacks(num_args=1, cache_context=True) + def _matches_user_in_member_list(self, room_id, store, cache_context): + member_list = yield store.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate + ) # check joined member events for user_id in member_list: @@ -219,10 +221,10 @@ class ApplicationService(object): ) def is_interested_in_alias(self, alias): - return self._matches_regex(alias, ApplicationService.NS_ALIASES) + return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) def is_interested_in_room(self, room_id): - return self._matches_regex(room_id, ApplicationService.NS_ROOMS) + return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) def is_exclusive_user(self, user_id): return ( diff --git a/synapse/config/voip.py b/synapse/config/voip.py index eeb693027b..3a4e16fa96 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -23,6 +23,7 @@ class VoipConfig(Config): self.turn_username = config.get("turn_username") self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) + self.turn_allow_guests = config.get("turn_allow_guests", True) def default_config(self, **kwargs): return """\ @@ -41,4 +42,11 @@ class VoipConfig(Config): # How long generated TURN credentials last turn_user_lifetime: "1h" + + # Whether guests should be allowed to use the TURN server. + # This defaults to True, otherwise VoIP will be unreliable for guests. + # However, it does introduce a slight security risk as it allows users to + # connect to arbitrary endpoints without having first signed up for a + # valid account (e.g. by passing a CAPTCHA). + turn_allow_guests: True """ diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 510a176821..bc20b9c201 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -146,11 +146,15 @@ class FederationServer(FederationBase): # check that it's actually being sent from a valid destination to # workaround bug #1753 in 0.18.5 and 0.18.6 if transaction.origin != get_domain_from_id(pdu.event_id): + # We continue to accept join events from any server; this is + # necessary for the federation join dance to work correctly. + # (When we join over federation, the "helper" server is + # responsible for sending out the join event, rather than the + # origin. See bug #1893). if not ( pdu.type == 'm.room.member' and pdu.content and - pdu.content.get("membership", None) == 'join' and - self.hs.is_mine_id(pdu.state_key) + pdu.content.get("membership", None) == 'join' ): logger.info( "Discarding PDU %s from invalid origin %s", diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 888dd01240..6ed5ce9e10 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -28,7 +28,7 @@ from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_fn, preserve_context_over_deferred + preserve_fn, preserve_context_over_deferred ) from synapse.util.metrics import measure_func from synapse.util.logutils import log_function @@ -394,11 +394,10 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) if event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -916,11 +915,10 @@ class FederationHandler(BaseHandler): origin, auth_chain, state, event ) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[joinee] - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[joinee] + ) logger.debug("Finished joining %s to %s", joinee, room_id) finally: @@ -1004,9 +1002,19 @@ class FederationHandler(BaseHandler): ) event.internal_metadata.outlier = False - # Send this event on behalf of the origin server since they may not - # have an up to data view of the state of the room at this event so - # will not know which servers to send the event to. + # Send this event on behalf of the origin server. + # + # The reasons we have the destination server rather than the origin + # server send it are slightly mysterious: the origin server should have + # all the neccessary state once it gets the response to the send_join, + # so it could send the event itself if it wanted to. It may be that + # doing it this way reduces failure modes, or avoids certain attacks + # where a new server selectively tells a subset of the federation that + # it has joined. + # + # The fact is that, as of the current writing, Synapse doesn't send out + # the join event over federation after joining, and changing it now + # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin context, event_stream_id, max_stream_id = yield self._handle_new_event( @@ -1025,10 +1033,9 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: @@ -1074,11 +1081,10 @@ class FederationHandler(BaseHandler): ) target_user = UserID.from_string(event.state_key) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=[target_user], - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=[target_user], + ) defer.returnValue(event) @@ -1236,10 +1242,9 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(target_user_id) extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users - ) + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, extra_users=extra_users + ) defer.returnValue(None) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7a498af5a2..348056add5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -612,7 +612,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def _notify(): yield run_on_reactor() - yield self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) diff --git a/synapse/notifier.py b/synapse/notifier.py index 4fda184b7a..48566187ab 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -210,7 +210,6 @@ class Notifier(object): """ self.replication_callbacks.append(cb) - @preserve_fn def on_new_room_event(self, event, room_stream_id, max_room_stream_id, extra_users=[]): """ Used by handlers to inform the notifier something has happened @@ -224,15 +223,13 @@ class Notifier(object): until all previous events have been persisted before notifying the client streams. """ - with PreserveLoggingContext(): - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) - self._notify_pending_new_room_events(max_room_stream_id) + self.pending_new_room_events.append(( + room_stream_id, event, extra_users + )) + self._notify_pending_new_room_events(max_room_stream_id) - self.notify_replication() + self.notify_replication() - @preserve_fn def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. @@ -250,14 +247,16 @@ class Notifier(object): else: self._on_new_room_event(event, room_stream_id, extra_users) - @preserve_fn def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. - self.appservice_handler.notify_interested_services(room_stream_id) + preserve_fn(self.appservice_handler.notify_interested_services)( + room_stream_id) if self.federation_sender: - self.federation_sender.notify_new_events(room_stream_id) + preserve_fn(self.federation_sender.notify_new_events)( + room_stream_id + ) if event.type == EventTypes.Member and event.membership == Membership.JOIN: self._user_joined_room(event.state_key, event.room_id) @@ -268,7 +267,6 @@ class Notifier(object): rooms=[event.room_id], ) - @preserve_fn def on_new_event(self, stream_key, new_token, users=[], rooms=[]): """ Used to inform listeners that something has happend event wise. @@ -295,7 +293,6 @@ class Notifier(object): self.notify_replication() - @preserve_fn def on_new_replication_data(self): """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4db76f18bd..4d88046579 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -17,6 +17,7 @@ import logging import re from synapse.types import UserID +from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object): return self._value_cache.get(dotted_key, None) +# Caches (glob, word_boundary) -> regex for push. See _glob_matches +regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +register_cache("regex_push_cache", regex_cache) + + def _glob_matches(glob, value, word_boundary=False): """Tests if value matches glob. @@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False): Returns: bool """ - try: - if IS_GLOB.search(glob): - r = re.escape(glob) - - r = r.replace(r'\*', '.*?') - r = r.replace(r'\?', '.') - - # handle [abc], [a-z] and [!a-z] style ranges. - r = GLOB_REGEX.sub( - lambda x: ( - '[%s%s]' % ( - x.group(1) and '^' or '', - x.group(2).replace(r'\\\-', '-') - ) - ), - r, - ) - if word_boundary: - r = r"\b%s\b" % (r,) - r = _compile_regex(r) - - return r.search(value) - else: - r = r + "$" - r = _compile_regex(r) - - return r.match(value) - elif word_boundary: - r = re.escape(glob) - r = r"\b%s\b" % (r,) - r = _compile_regex(r) - return r.search(value) - else: - return value.lower() == glob.lower() + try: + r = regex_cache.get((glob, word_boundary), None) + if not r: + r = _glob_to_re(glob, word_boundary) + regex_cache[(glob, word_boundary)] = r + return r.search(value) except re.error: logger.warn("Failed to parse glob to regex: %r", glob) return False +def _glob_to_re(glob, word_boundary): + """Generates regex for a given glob. + + Args: + glob (string) + word_boundary (bool): Whether to match against word boundaries or entire + string. Defaults to False. + + Returns: + regex object + """ + if IS_GLOB.search(glob): + r = re.escape(glob) + + r = r.replace(r'\*', '.*?') + r = r.replace(r'\?', '.') + + # handle [abc], [a-z] and [!a-z] style ranges. + r = GLOB_REGEX.sub( + lambda x: ( + '[%s%s]' % ( + x.group(1) and '^' or '', + x.group(2).replace(r'\\\-', '-') + ) + ), + r, + ) + if word_boundary: + r = r"\b%s\b" % (r,) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + r + "$" + + return re.compile(r, flags=re.IGNORECASE) + elif word_boundary: + r = re.escape(glob) + r = r"\b%s\b" % (r,) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + re.escape(glob) + "$" + return re.compile(r, flags=re.IGNORECASE) + + def _flatten_dict(d, prefix=[], result={}): for key, value in d.items(): if isinstance(value, basestring): @@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result - - -regex_cache = LruCache(5000) - - -def _compile_regex(regex_str): - r = regex_cache.get(regex_str, None) - if r: - return r - - r = re.compile(regex_str, flags=re.IGNORECASE) - regex_cache[regex_str] = r - return r diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 287df94b4f..6835f54e97 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,15 +17,12 @@ from twisted.internet import defer from synapse.push.presentable_names import ( calculate_room_name, name_from_member_event ) -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @defer.inlineCallbacks def get_badge_count(store, user_id): - invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.get_invited_rooms_for_user)(user_id), - preserve_fn(store.get_rooms_for_user)(user_id), - ], consumeErrors=True)) + invites = yield store.get_invited_rooms_for_user(user_id) + joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user( user_id, "m.read", diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 03141c623c..c43b30b73a 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req( + request, + self.hs.config.turn_allow_guests + ) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 31f94bc6e9..6fceb23e26 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -36,7 +36,7 @@ class ThirdPartyProtocolsServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols() defer.returnValue((200, protocols)) @@ -54,7 +54,7 @@ class ThirdPartyProtocolServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( only_protocol=protocol, @@ -77,7 +77,7 @@ class ThirdPartyUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) @@ -101,7 +101,7 @@ class ThirdPartyLocationServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) diff --git a/synapse/types.py b/synapse/types.py index 9666f9d73f..c87ed813b9 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -216,9 +216,7 @@ class StreamToken( return self def copy_and_replace(self, key, new_value): - d = self._asdict() - d[key] = new_value - return StreamToken(**d) + return self._replace(**{key: new_value}) StreamToken.START = StreamToken( diff --git a/synapse/util/async.py b/synapse/util/async.py index 35380bf8ed..1453faf0ef 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -89,6 +89,11 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) def observe(self): + """Observe the underlying deferred. + + Can return either a deferred if the underlying deferred is still pending + (or has failed), or the actual value. Callers may need to use maybeDeferred. + """ if not self._result: d = defer.Deferred() @@ -101,7 +106,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return res if success else defer.fail(res) def observers(self): return self._observers diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 19595df422..9d0d0be1f9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -15,12 +15,9 @@ import logging from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError +from synapse.util import unwrapFirstError, logcontext from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn -) from . import DEBUG_CACHES, register_cache @@ -227,8 +224,20 @@ class _CacheDescriptorBase(object): ) self.num_args = num_args + + # list of the names of the args used as the cache key self.arg_names = all_args[1:num_args + 1] + # self.arg_defaults is a map of arg name to its default value for each + # argument that has a default value + if arg_spec.defaults: + self.arg_defaults = dict(zip( + all_args[-len(arg_spec.defaults):], + arg_spec.defaults + )) + else: + self.arg_defaults = {} + if "cache_context" in self.arg_names: raise Exception( "cache_context arg cannot be included among the cache keys" @@ -292,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) + def get_cache_key(args, kwargs): + """Given some args/kwargs return a generator that resolves into + the cache_key. + + We loop through each arg name, looking up if its in the `kwargs`, + otherwise using the next argument in `args`. If there are no more + args then we try looking the arg name up in the defaults + """ + pos = 0 + for nm in self.arg_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield self.arg_defaults[nm] + @functools.wraps(self.orig) def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - # Add temp cache_context so inspect.getcallargs doesn't explode - if self.add_cache_context: - kwargs["cache_context"] = None - - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + cache_key = tuple(get_cache_key(args, kwargs)) # Add our own `cache_context` to argument list if the wrapped function # has asked for one @@ -328,11 +350,9 @@ class CacheDescriptor(_CacheDescriptorBase): defer.returnValue(cached_result) observer.addCallback(check_result) - return preserve_context_over_deferred(observer) except KeyError: ret = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) @@ -342,10 +362,14 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=True) - cache.set(cache_key, ret, callback=invalidate_callback) + result_d = ObservableDeferred(ret, consumeErrors=True) + cache.set(cache_key, result_d, callback=invalidate_callback) + observer = result_d.observe() - return preserve_context_over_deferred(ret.observe()) + if isinstance(observer, defer.Deferred): + return logcontext.make_deferred_yieldable(observer) + else: + return observer wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all @@ -362,7 +386,11 @@ class CacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped fucntion. + the list of missing keys to the wrapped function. + + Once wrapped, the function returns either a Deferred which resolves to + the list of results, or (if all results were cached), just the list of + results. """ def __init__(self, orig, cached_method_name, list_name, num_args=None, @@ -433,8 +461,7 @@ class CacheListDescriptor(_CacheDescriptorBase): args_to_call[self.list_name] = missing ret_d = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), **args_to_call ) @@ -443,8 +470,7 @@ class CacheListDescriptor(_CacheDescriptorBase): # We need to create deferreds for each arg in the list so that # we can insert the new deferred into the cache. for arg in missing: - with PreserveLoggingContext(): - observer = ret_d.observe() + observer = ret_d.observe() observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) @@ -471,7 +497,7 @@ class CacheListDescriptor(_CacheDescriptorBase): results.update(res) return results - return preserve_context_over_deferred(defer.gatherResults( + return logcontext.make_deferred_yieldable(defer.gatherResults( cached_defers.values(), consumeErrors=True, ).addCallback(update_results_dict).addErrback( diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index ff67b1d794..990216145e 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs): def preserve_context_over_deferred(deferred, context=None): """Given a deferred wrap it such that any callbacks added later to it will be invoked with the current context. + + Deprecated: this almost certainly doesn't do want you want, ie make + the deferred follow the synapse logcontext rules: try + ``make_deferred_yieldable`` instead. """ if context is None: context = LoggingContext.current_context() @@ -330,12 +334,8 @@ def preserve_fn(f): LoggingContext.set_current_context(LoggingContext.sentinel) return result - # XXX: why is this here rather than inside g? surely we want to preserve - # the context from the time the function was called, not when it was - # wrapped? - current = LoggingContext.current_context() - def g(*args, **kwargs): + current = LoggingContext.current_context() res = f(*args, **kwargs) if isinstance(res, defer.Deferred) and not res.called: # The function will have reset the context before returning, so @@ -359,6 +359,25 @@ def preserve_fn(f): return g +@defer.inlineCallbacks +def make_deferred_yieldable(deferred): + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed (or is not actually a Deferred), essentially + does nothing (just returns another completed deferred with the + result/failure). + + If the deferred has not yet completed, resets the logcontext before + returning a deferred. Then, when the deferred completes, restores the + current logcontext before running callbacks/errbacks. + + (This is more-or-less the opposite operation to preserve_fn.) + """ + with PreserveLoggingContext(): + r = yield deferred + defer.returnValue(r) + + # modules to ignore in `logcontext_tracer` _to_ignore = [ "synapse.util.logcontext", diff --git a/synapse/visibility.py b/synapse/visibility.py index 31659156ae..c4dd9ae2c7 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): events ([synapse.events.EventBase]): list of events to filter """ forgotten = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.who_forgot_in_room)( + defer.maybeDeferred( + preserve_fn(store.who_forgot_in_room), room_id, ) for room_id in frozenset(e.room_id for e in events) |