diff --git a/synapse/__init__.py b/synapse/__init__.py
index 7628e7c505..580927abf4 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.19.3"
+__version__ = "0.20.0-rc1"
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 23eb6a1ec4..e8218d01ad 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -202,7 +202,8 @@ def main():
worker_app = worker_config["worker_app"]
worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"]
- assert worker_daemonize # TODO print something more user friendly
+ assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+ worker_configfile, "worker_daemonize")
worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b0106a3597..7346206bb1 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import EventTypes
+from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer
@@ -124,29 +125,23 @@ class ApplicationService(object):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
- if not isinstance(regex_obj.get("regex"), basestring):
+ regex = regex_obj.get("regex")
+ if isinstance(regex, basestring):
+ regex_obj["regex"] = re.compile(regex) # Pre-compile regex
+ else:
raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns
)
return namespaces
- def _matches_regex(self, test_string, namespace_key, return_obj=False):
- if not isinstance(test_string, basestring):
- logger.error(
- "Expected a string to test regex against, but got %s",
- test_string
- )
- return False
-
+ def _matches_regex(self, test_string, namespace_key):
for regex_obj in self.namespaces[namespace_key]:
- if re.match(regex_obj["regex"], test_string):
- if return_obj:
- return regex_obj
- return True
- return False
+ if regex_obj["regex"].match(test_string):
+ return regex_obj
+ return None
def _is_exclusive(self, ns_key, test_string):
- regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+ regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj:
return regex_obj["exclusive"]
return False
@@ -166,7 +161,14 @@ class ApplicationService(object):
if not store:
defer.returnValue(False)
- member_list = yield store.get_users_in_room(event.room_id)
+ does_match = yield self._matches_user_in_member_list(event.room_id, store)
+ defer.returnValue(does_match)
+
+ @cachedInlineCallbacks(num_args=1, cache_context=True)
+ def _matches_user_in_member_list(self, room_id, store, cache_context):
+ member_list = yield store.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
# check joined member events
for user_id in member_list:
@@ -219,10 +221,10 @@ class ApplicationService(object):
)
def is_interested_in_alias(self, alias):
- return self._matches_regex(alias, ApplicationService.NS_ALIASES)
+ return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
def is_interested_in_room(self, room_id):
- return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+ return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
def is_exclusive_user(self, user_id):
return (
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index eeb693027b..3a4e16fa96 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -23,6 +23,7 @@ class VoipConfig(Config):
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
+ self.turn_allow_guests = config.get("turn_allow_guests", True)
def default_config(self, **kwargs):
return """\
@@ -41,4 +42,11 @@ class VoipConfig(Config):
# How long generated TURN credentials last
turn_user_lifetime: "1h"
+
+ # Whether guests should be allowed to use the TURN server.
+ # This defaults to True, otherwise VoIP will be unreliable for guests.
+ # However, it does introduce a slight security risk as it allows users to
+ # connect to arbitrary endpoints without having first signed up for a
+ # valid account (e.g. by passing a CAPTCHA).
+ turn_allow_guests: True
"""
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 510a176821..bc20b9c201 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -146,11 +146,15 @@ class FederationServer(FederationBase):
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if transaction.origin != get_domain_from_id(pdu.event_id):
+ # We continue to accept join events from any server; this is
+ # necessary for the federation join dance to work correctly.
+ # (When we join over federation, the "helper" server is
+ # responsible for sending out the join event, rather than the
+ # origin. See bug #1893).
if not (
pdu.type == 'm.room.member' and
pdu.content and
- pdu.content.get("membership", None) == 'join' and
- self.hs.is_mine_id(pdu.state_key)
+ pdu.content.get("membership", None) == 'join'
):
logger.info(
"Discarding PDU %s from invalid origin %s",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 888dd01240..6ed5ce9e10 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -28,7 +28,7 @@ from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
+ preserve_fn, preserve_context_over_deferred
)
from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function
@@ -394,11 +394,10 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=extra_users
+ )
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -916,11 +915,10 @@ class FederationHandler(BaseHandler):
origin, auth_chain, state, event
)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=[joinee]
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=[joinee]
+ )
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
@@ -1004,9 +1002,19 @@ class FederationHandler(BaseHandler):
)
event.internal_metadata.outlier = False
- # Send this event on behalf of the origin server since they may not
- # have an up to data view of the state of the room at this event so
- # will not know which servers to send the event to.
+ # Send this event on behalf of the origin server.
+ #
+ # The reasons we have the destination server rather than the origin
+ # server send it are slightly mysterious: the origin server should have
+ # all the neccessary state once it gets the response to the send_join,
+ # so it could send the event itself if it wanted to. It may be that
+ # doing it this way reduces failure modes, or avoids certain attacks
+ # where a new server selectively tells a subset of the federation that
+ # it has joined.
+ #
+ # The fact is that, as of the current writing, Synapse doesn't send out
+ # the join event over federation after joining, and changing it now
+ # would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -1025,10 +1033,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id, extra_users=extra_users
+ )
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
@@ -1074,11 +1081,10 @@ class FederationHandler(BaseHandler):
)
target_user = UserID.from_string(event.state_key)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=[target_user],
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=[target_user],
+ )
defer.returnValue(event)
@@ -1236,10 +1242,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
- with PreserveLoggingContext():
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
- )
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id, extra_users=extra_users
+ )
defer.returnValue(None)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a498af5a2..348056add5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -612,7 +612,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
- yield self.notifier.on_new_room_event(
+ self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 4fda184b7a..48566187ab 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -210,7 +210,6 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
- @preserve_fn
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@@ -224,15 +223,13 @@ class Notifier(object):
until all previous events have been persisted before notifying
the client streams.
"""
- with PreserveLoggingContext():
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
- self._notify_pending_new_room_events(max_room_stream_id)
+ self.pending_new_room_events.append((
+ room_stream_id, event, extra_users
+ ))
+ self._notify_pending_new_room_events(max_room_stream_id)
- self.notify_replication()
+ self.notify_replication()
- @preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
@@ -250,14 +247,16 @@ class Notifier(object):
else:
self._on_new_room_event(event, room_stream_id, extra_users)
- @preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
- self.appservice_handler.notify_interested_services(room_stream_id)
+ preserve_fn(self.appservice_handler.notify_interested_services)(
+ room_stream_id)
if self.federation_sender:
- self.federation_sender.notify_new_events(room_stream_id)
+ preserve_fn(self.federation_sender.notify_new_events)(
+ room_stream_id
+ )
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@@ -268,7 +267,6 @@ class Notifier(object):
rooms=[event.room_id],
)
- @preserve_fn
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
@@ -295,7 +293,6 @@ class Notifier(object):
self.notify_replication()
- @preserve_fn
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4db76f18bd..4d88046579 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -17,6 +17,7 @@ import logging
import re
from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
return self._value_cache.get(dotted_key, None)
+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
def _glob_matches(glob, value, word_boundary=False):
"""Tests if value matches glob.
@@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
Returns:
bool
"""
- try:
- if IS_GLOB.search(glob):
- r = re.escape(glob)
-
- r = r.replace(r'\*', '.*?')
- r = r.replace(r'\?', '.')
-
- # handle [abc], [a-z] and [!a-z] style ranges.
- r = GLOB_REGEX.sub(
- lambda x: (
- '[%s%s]' % (
- x.group(1) and '^' or '',
- x.group(2).replace(r'\\\-', '-')
- )
- ),
- r,
- )
- if word_boundary:
- r = r"\b%s\b" % (r,)
- r = _compile_regex(r)
-
- return r.search(value)
- else:
- r = r + "$"
- r = _compile_regex(r)
-
- return r.match(value)
- elif word_boundary:
- r = re.escape(glob)
- r = r"\b%s\b" % (r,)
- r = _compile_regex(r)
- return r.search(value)
- else:
- return value.lower() == glob.lower()
+ try:
+ r = regex_cache.get((glob, word_boundary), None)
+ if not r:
+ r = _glob_to_re(glob, word_boundary)
+ regex_cache[(glob, word_boundary)] = r
+ return r.search(value)
except re.error:
logger.warn("Failed to parse glob to regex: %r", glob)
return False
+def _glob_to_re(glob, word_boundary):
+ """Generates regex for a given glob.
+
+ Args:
+ glob (string)
+ word_boundary (bool): Whether to match against word boundaries or entire
+ string. Defaults to False.
+
+ Returns:
+ regex object
+ """
+ if IS_GLOB.search(glob):
+ r = re.escape(glob)
+
+ r = r.replace(r'\*', '.*?')
+ r = r.replace(r'\?', '.')
+
+ # handle [abc], [a-z] and [!a-z] style ranges.
+ r = GLOB_REGEX.sub(
+ lambda x: (
+ '[%s%s]' % (
+ x.group(1) and '^' or '',
+ x.group(2).replace(r'\\\-', '-')
+ )
+ ),
+ r,
+ )
+ if word_boundary:
+ r = r"\b%s\b" % (r,)
+
+ return re.compile(r, flags=re.IGNORECASE)
+ else:
+ r = "^" + r + "$"
+
+ return re.compile(r, flags=re.IGNORECASE)
+ elif word_boundary:
+ r = re.escape(glob)
+ r = r"\b%s\b" % (r,)
+
+ return re.compile(r, flags=re.IGNORECASE)
+ else:
+ r = "^" + re.escape(glob) + "$"
+ return re.compile(r, flags=re.IGNORECASE)
+
+
def _flatten_dict(d, prefix=[], result={}):
for key, value in d.items():
if isinstance(value, basestring):
@@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
- r = regex_cache.get(regex_str, None)
- if r:
- return r
-
- r = re.compile(regex_str, flags=re.IGNORECASE)
- regex_cache[regex_str] = r
- return r
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 287df94b4f..6835f54e97 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,15 +17,12 @@ from twisted.internet import defer
from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event
)
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks
def get_badge_count(store, user_id):
- invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(store.get_invited_rooms_for_user)(user_id),
- preserve_fn(store.get_rooms_for_user)(user_id),
- ], consumeErrors=True))
+ invites = yield store.get_invited_rooms_for_user(user_id)
+ joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 03141c623c..c43b30b73a 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(
+ request,
+ self.hs.config.turn_allow_guests
+ )
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 31f94bc6e9..6fceb23e26 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -36,7 +36,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols()
defer.returnValue((200, protocols))
@@ -54,7 +54,7 @@ class ThirdPartyProtocolServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols(
only_protocol=protocol,
@@ -77,7 +77,7 @@ class ThirdPartyUserServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop("access_token", None)
@@ -101,7 +101,7 @@ class ThirdPartyLocationServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request)
+ yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop("access_token", None)
diff --git a/synapse/types.py b/synapse/types.py
index 9666f9d73f..c87ed813b9 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -216,9 +216,7 @@ class StreamToken(
return self
def copy_and_replace(self, key, new_value):
- d = self._asdict()
- d[key] = new_value
- return StreamToken(**d)
+ return self._replace(**{key: new_value})
StreamToken.START = StreamToken(
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..1453faf0ef 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
def observe(self):
+ """Observe the underlying deferred.
+
+ Can return either a deferred if the underlying deferred is still pending
+ (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ """
if not self._result:
d = defer.Deferred()
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
return d
else:
success, res = self._result
- return defer.succeed(res) if success else defer.fail(res)
+ return res if success else defer.fail(res)
def observers(self):
return self._observers
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 19595df422..9d0d0be1f9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -15,12 +15,9 @@
import logging
from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
from . import DEBUG_CACHES, register_cache
@@ -227,8 +224,20 @@ class _CacheDescriptorBase(object):
)
self.num_args = num_args
+
+ # list of the names of the args used as the cache key
self.arg_names = all_args[1:num_args + 1]
+ # self.arg_defaults is a map of arg name to its default value for each
+ # argument that has a default value
+ if arg_spec.defaults:
+ self.arg_defaults = dict(zip(
+ all_args[-len(arg_spec.defaults):],
+ arg_spec.defaults
+ ))
+ else:
+ self.arg_defaults = {}
+
if "cache_context" in self.arg_names:
raise Exception(
"cache_context arg cannot be included among the cache keys"
@@ -292,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
)
+ def get_cache_key(args, kwargs):
+ """Given some args/kwargs return a generator that resolves into
+ the cache_key.
+
+ We loop through each arg name, looking up if its in the `kwargs`,
+ otherwise using the next argument in `args`. If there are no more
+ args then we try looking the arg name up in the defaults
+ """
+ pos = 0
+ for nm in self.arg_names:
+ if nm in kwargs:
+ yield kwargs[nm]
+ elif pos < len(args):
+ yield args[pos]
+ pos += 1
+ else:
+ yield self.arg_defaults[nm]
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
- # Add temp cache_context so inspect.getcallargs doesn't explode
- if self.add_cache_context:
- kwargs["cache_context"] = None
-
- arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
- cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+ cache_key = tuple(get_cache_key(args, kwargs))
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
@@ -328,11 +350,9 @@ class CacheDescriptor(_CacheDescriptorBase):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return preserve_context_over_deferred(observer)
except KeyError:
ret = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
obj, *args, **kwargs
)
@@ -342,10 +362,14 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- ret = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, ret, callback=invalidate_callback)
+ result_d = ObservableDeferred(ret, consumeErrors=True)
+ cache.set(cache_key, result_d, callback=invalidate_callback)
+ observer = result_d.observe()
- return preserve_context_over_deferred(ret.observe())
+ if isinstance(observer, defer.Deferred):
+ return logcontext.make_deferred_yieldable(observer)
+ else:
+ return observer
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
@@ -362,7 +386,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
- the list of missing keys to the wrapped fucntion.
+ the list of missing keys to the wrapped function.
+
+ Once wrapped, the function returns either a Deferred which resolves to
+ the list of results, or (if all results were cached), just the list of
+ results.
"""
def __init__(self, orig, cached_method_name, list_name, num_args=None,
@@ -433,8 +461,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
- preserve_context_over_fn,
- self.function_to_call,
+ logcontext.preserve_fn(self.function_to_call),
**args_to_call
)
@@ -443,8 +470,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- with PreserveLoggingContext():
- observer = ret_d.observe()
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -471,7 +497,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
results.update(res)
return results
- return preserve_context_over_deferred(defer.gatherResults(
+ return logcontext.make_deferred_yieldable(defer.gatherResults(
cached_defers.values(),
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index ff67b1d794..990216145e 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
+
+ Deprecated: this almost certainly doesn't do want you want, ie make
+ the deferred follow the synapse logcontext rules: try
+ ``make_deferred_yieldable`` instead.
"""
if context is None:
context = LoggingContext.current_context()
@@ -330,12 +334,8 @@ def preserve_fn(f):
LoggingContext.set_current_context(LoggingContext.sentinel)
return result
- # XXX: why is this here rather than inside g? surely we want to preserve
- # the context from the time the function was called, not when it was
- # wrapped?
- current = LoggingContext.current_context()
-
def g(*args, **kwargs):
+ current = LoggingContext.current_context()
res = f(*args, **kwargs)
if isinstance(res, defer.Deferred) and not res.called:
# The function will have reset the context before returning, so
@@ -359,6 +359,25 @@ def preserve_fn(f):
return g
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+ """Given a deferred, make it follow the Synapse logcontext rules:
+
+ If the deferred has completed (or is not actually a Deferred), essentially
+ does nothing (just returns another completed deferred with the
+ result/failure).
+
+ If the deferred has not yet completed, resets the logcontext before
+ returning a deferred. Then, when the deferred completes, restores the
+ current logcontext before running callbacks/errbacks.
+
+ (This is more-or-less the opposite operation to preserve_fn.)
+ """
+ with PreserveLoggingContext():
+ r = yield deferred
+ defer.returnValue(r)
+
+
# modules to ignore in `logcontext_tracer`
_to_ignore = [
"synapse.util.logcontext",
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 31659156ae..c4dd9ae2c7 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(store.who_forgot_in_room)(
+ defer.maybeDeferred(
+ preserve_fn(store.who_forgot_in_room),
room_id,
)
for room_id in frozenset(e.room_id for e in events)
|