summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rwxr-xr-xsynapse/app/synctl.py3
-rw-r--r--synapse/appservice/__init__.py38
-rw-r--r--synapse/config/voip.py8
-rw-r--r--synapse/federation/federation_server.py8
-rw-r--r--synapse/handlers/federation.py59
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/notifier.py23
-rw-r--r--synapse/push/push_rule_evaluator.py104
-rw-r--r--synapse/push/push_tools.py7
-rw-r--r--synapse/rest/client/v1/voip.py5
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py8
-rw-r--r--synapse/types.py4
-rw-r--r--synapse/util/async.py7
-rw-r--r--synapse/util/caches/descriptors.py70
-rw-r--r--synapse/util/logcontext.py29
-rw-r--r--synapse/visibility.py3
17 files changed, 228 insertions, 152 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 7628e7c505..580927abf4 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.19.3"
+__version__ = "0.20.0-rc1"
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 23eb6a1ec4..e8218d01ad 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -202,7 +202,8 @@ def main():
         worker_app = worker_config["worker_app"]
         worker_pidfile = worker_config["worker_pid_file"]
         worker_daemonize = worker_config["worker_daemonize"]
-        assert worker_daemonize  # TODO print something more user friendly
+        assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+            worker_configfile, "worker_daemonize")
         worker_cache_factor = worker_config.get("synctl_cache_factor")
         workers.append(Worker(
             worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b0106a3597..7346206bb1 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.api.constants import EventTypes
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -124,29 +125,23 @@ class ApplicationService(object):
                     raise ValueError(
                         "Expected bool for 'exclusive' in ns '%s'" % ns
                     )
-                if not isinstance(regex_obj.get("regex"), basestring):
+                regex = regex_obj.get("regex")
+                if isinstance(regex, basestring):
+                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
+                else:
                     raise ValueError(
                         "Expected string for 'regex' in ns '%s'" % ns
                     )
         return namespaces
 
-    def _matches_regex(self, test_string, namespace_key, return_obj=False):
-        if not isinstance(test_string, basestring):
-            logger.error(
-                "Expected a string to test regex against, but got %s",
-                test_string
-            )
-            return False
-
+    def _matches_regex(self, test_string, namespace_key):
         for regex_obj in self.namespaces[namespace_key]:
-            if re.match(regex_obj["regex"], test_string):
-                if return_obj:
-                    return regex_obj
-                return True
-        return False
+            if regex_obj["regex"].match(test_string):
+                return regex_obj
+        return None
 
     def _is_exclusive(self, ns_key, test_string):
-        regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+        regex_obj = self._matches_regex(test_string, ns_key)
         if regex_obj:
             return regex_obj["exclusive"]
         return False
@@ -166,7 +161,14 @@ class ApplicationService(object):
         if not store:
             defer.returnValue(False)
 
-        member_list = yield store.get_users_in_room(event.room_id)
+        does_match = yield self._matches_user_in_member_list(event.room_id, store)
+        defer.returnValue(does_match)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _matches_user_in_member_list(self, room_id, store, cache_context):
+        member_list = yield store.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
 
         # check joined member events
         for user_id in member_list:
@@ -219,10 +221,10 @@ class ApplicationService(object):
         )
 
     def is_interested_in_alias(self, alias):
-        return self._matches_regex(alias, ApplicationService.NS_ALIASES)
+        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
 
     def is_interested_in_room(self, room_id):
-        return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
 
     def is_exclusive_user(self, user_id):
         return (
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index eeb693027b..3a4e16fa96 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -23,6 +23,7 @@ class VoipConfig(Config):
         self.turn_username = config.get("turn_username")
         self.turn_password = config.get("turn_password")
         self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
+        self.turn_allow_guests = config.get("turn_allow_guests", True)
 
     def default_config(self, **kwargs):
         return """\
@@ -41,4 +42,11 @@ class VoipConfig(Config):
 
         # How long generated TURN credentials last
         turn_user_lifetime: "1h"
+
+        # Whether guests should be allowed to use the TURN server.
+        # This defaults to True, otherwise VoIP will be unreliable for guests.
+        # However, it does introduce a slight security risk as it allows users to
+        # connect to arbitrary endpoints without having first signed up for a
+        # valid account (e.g. by passing a CAPTCHA).
+        turn_allow_guests: True
         """
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 510a176821..bc20b9c201 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -146,11 +146,15 @@ class FederationServer(FederationBase):
             # check that it's actually being sent from a valid destination to
             # workaround bug #1753 in 0.18.5 and 0.18.6
             if transaction.origin != get_domain_from_id(pdu.event_id):
+                # We continue to accept join events from any server; this is
+                # necessary for the federation join dance to work correctly.
+                # (When we join over federation, the "helper" server is
+                # responsible for sending out the join event, rather than the
+                # origin. See bug #1893).
                 if not (
                     pdu.type == 'm.room.member' and
                     pdu.content and
-                    pdu.content.get("membership", None) == 'join' and
-                    self.hs.is_mine_id(pdu.state_key)
+                    pdu.content.get("membership", None) == 'join'
                 ):
                     logger.info(
                         "Discarding PDU %s from invalid origin %s",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 888dd01240..6ed5ce9e10 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -28,7 +28,7 @@ from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
 from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
+    preserve_fn, preserve_context_over_deferred
 )
 from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
@@ -394,11 +394,10 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=extra_users
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=extra_users
+        )
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -916,11 +915,10 @@ class FederationHandler(BaseHandler):
                 origin, auth_chain, state, event
             )
 
-            with PreserveLoggingContext():
-                self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id,
-                    extra_users=[joinee]
-                )
+            self.notifier.on_new_room_event(
+                event, event_stream_id, max_stream_id,
+                extra_users=[joinee]
+            )
 
             logger.debug("Finished joining %s to %s", joinee, room_id)
         finally:
@@ -1004,9 +1002,19 @@ class FederationHandler(BaseHandler):
         )
 
         event.internal_metadata.outlier = False
-        # Send this event on behalf of the origin server since they may not
-        # have an up to data view of the state of the room at this event so
-        # will not know which servers to send the event to.
+        # Send this event on behalf of the origin server.
+        #
+        # The reasons we have the destination server rather than the origin
+        # server send it are slightly mysterious: the origin server should have
+        # all the neccessary state once it gets the response to the send_join,
+        # so it could send the event itself if it wanted to. It may be that
+        # doing it this way reduces failure modes, or avoids certain attacks
+        # where a new server selectively tells a subset of the federation that
+        # it has joined.
+        #
+        # The fact is that, as of the current writing, Synapse doesn't send out
+        # the join event over federation after joining, and changing it now
+        # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
         context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -1025,10 +1033,9 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id, extra_users=extra_users
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id, extra_users=extra_users
+        )
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
@@ -1074,11 +1081,10 @@ class FederationHandler(BaseHandler):
         )
 
         target_user = UserID.from_string(event.state_key)
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=[target_user],
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=[target_user],
+        )
 
         defer.returnValue(event)
 
@@ -1236,10 +1242,9 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id, extra_users=extra_users
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id, extra_users=extra_users
+        )
 
         defer.returnValue(None)
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a498af5a2..348056add5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -612,7 +612,7 @@ class MessageHandler(BaseHandler):
         @defer.inlineCallbacks
         def _notify():
             yield run_on_reactor()
-            yield self.notifier.on_new_room_event(
+            self.notifier.on_new_room_event(
                 event, event_stream_id, max_stream_id,
                 extra_users=extra_users
             )
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 4fda184b7a..48566187ab 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -210,7 +210,6 @@ class Notifier(object):
         """
         self.replication_callbacks.append(cb)
 
-    @preserve_fn
     def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                           extra_users=[]):
         """ Used by handlers to inform the notifier something has happened
@@ -224,15 +223,13 @@ class Notifier(object):
         until all previous events have been persisted before notifying
         the client streams.
         """
-        with PreserveLoggingContext():
-            self.pending_new_room_events.append((
-                room_stream_id, event, extra_users
-            ))
-            self._notify_pending_new_room_events(max_room_stream_id)
+        self.pending_new_room_events.append((
+            room_stream_id, event, extra_users
+        ))
+        self._notify_pending_new_room_events(max_room_stream_id)
 
-            self.notify_replication()
+        self.notify_replication()
 
-    @preserve_fn
     def _notify_pending_new_room_events(self, max_room_stream_id):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
@@ -250,14 +247,16 @@ class Notifier(object):
             else:
                 self._on_new_room_event(event, room_stream_id, extra_users)
 
-    @preserve_fn
     def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
-        self.appservice_handler.notify_interested_services(room_stream_id)
+        preserve_fn(self.appservice_handler.notify_interested_services)(
+            room_stream_id)
 
         if self.federation_sender:
-            self.federation_sender.notify_new_events(room_stream_id)
+            preserve_fn(self.federation_sender.notify_new_events)(
+                room_stream_id
+            )
 
         if event.type == EventTypes.Member and event.membership == Membership.JOIN:
             self._user_joined_room(event.state_key, event.room_id)
@@ -268,7 +267,6 @@ class Notifier(object):
             rooms=[event.room_id],
         )
 
-    @preserve_fn
     def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
         """ Used to inform listeners that something has happend event wise.
 
@@ -295,7 +293,6 @@ class Notifier(object):
 
                 self.notify_replication()
 
-    @preserve_fn
     def on_new_replication_data(self):
         """Used to inform replication listeners that something has happend
         without waking up any of the normal user event streams"""
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4db76f18bd..4d88046579 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -17,6 +17,7 @@ import logging
 import re
 
 from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -125,6 +126,11 @@ class PushRuleEvaluatorForEvent(object):
         return self._value_cache.get(dotted_key, None)
 
 
+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
 def _glob_matches(glob, value, word_boundary=False):
     """Tests if value matches glob.
 
@@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False):
     Returns:
         bool
     """
-    try:
-        if IS_GLOB.search(glob):
-            r = re.escape(glob)
-
-            r = r.replace(r'\*', '.*?')
-            r = r.replace(r'\?', '.')
-
-            # handle [abc], [a-z] and [!a-z] style ranges.
-            r = GLOB_REGEX.sub(
-                lambda x: (
-                    '[%s%s]' % (
-                        x.group(1) and '^' or '',
-                        x.group(2).replace(r'\\\-', '-')
-                    )
-                ),
-                r,
-            )
-            if word_boundary:
-                r = r"\b%s\b" % (r,)
-                r = _compile_regex(r)
-
-                return r.search(value)
-            else:
-                r = r + "$"
-                r = _compile_regex(r)
-
-                return r.match(value)
-        elif word_boundary:
-            r = re.escape(glob)
-            r = r"\b%s\b" % (r,)
-            r = _compile_regex(r)
 
-            return r.search(value)
-        else:
-            return value.lower() == glob.lower()
+    try:
+        r = regex_cache.get((glob, word_boundary), None)
+        if not r:
+            r = _glob_to_re(glob, word_boundary)
+            regex_cache[(glob, word_boundary)] = r
+        return r.search(value)
     except re.error:
         logger.warn("Failed to parse glob to regex: %r", glob)
         return False
 
 
+def _glob_to_re(glob, word_boundary):
+    """Generates regex for a given glob.
+
+    Args:
+        glob (string)
+        word_boundary (bool): Whether to match against word boundaries or entire
+            string. Defaults to False.
+
+    Returns:
+        regex object
+    """
+    if IS_GLOB.search(glob):
+        r = re.escape(glob)
+
+        r = r.replace(r'\*', '.*?')
+        r = r.replace(r'\?', '.')
+
+        # handle [abc], [a-z] and [!a-z] style ranges.
+        r = GLOB_REGEX.sub(
+            lambda x: (
+                '[%s%s]' % (
+                    x.group(1) and '^' or '',
+                    x.group(2).replace(r'\\\-', '-')
+                )
+            ),
+            r,
+        )
+        if word_boundary:
+            r = r"\b%s\b" % (r,)
+
+            return re.compile(r, flags=re.IGNORECASE)
+        else:
+            r = "^" + r + "$"
+
+            return re.compile(r, flags=re.IGNORECASE)
+    elif word_boundary:
+        r = re.escape(glob)
+        r = r"\b%s\b" % (r,)
+
+        return re.compile(r, flags=re.IGNORECASE)
+    else:
+        r = "^" + re.escape(glob) + "$"
+        return re.compile(r, flags=re.IGNORECASE)
+
+
 def _flatten_dict(d, prefix=[], result={}):
     for key, value in d.items():
         if isinstance(value, basestring):
@@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}):
             _flatten_dict(value, prefix=(prefix + [key]), result=result)
 
     return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
-    r = regex_cache.get(regex_str, None)
-    if r:
-        return r
-
-    r = re.compile(regex_str, flags=re.IGNORECASE)
-    regex_cache[regex_str] = r
-    return r
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 287df94b4f..6835f54e97 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,15 +17,12 @@ from twisted.internet import defer
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites = yield store.get_invited_rooms_for_user(user_id)
+    joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 03141c623c..c43b30b73a 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(
+            request,
+            self.hs.config.turn_allow_guests
+        )
 
         turnUris = self.hs.config.turn_uris
         turnSecret = self.hs.config.turn_shared_secret
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 31f94bc6e9..6fceb23e26 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -36,7 +36,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = yield self.appservice_handler.get_3pe_protocols()
         defer.returnValue((200, protocols))
@@ -54,7 +54,7 @@ class ThirdPartyProtocolServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = yield self.appservice_handler.get_3pe_protocols(
             only_protocol=protocol,
@@ -77,7 +77,7 @@ class ThirdPartyUserServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         fields = request.args
         fields.pop("access_token", None)
@@ -101,7 +101,7 @@ class ThirdPartyLocationServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         fields = request.args
         fields.pop("access_token", None)
diff --git a/synapse/types.py b/synapse/types.py
index 9666f9d73f..c87ed813b9 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -216,9 +216,7 @@ class StreamToken(
             return self
 
     def copy_and_replace(self, key, new_value):
-        d = self._asdict()
-        d[key] = new_value
-        return StreamToken(**d)
+        return self._replace(**{key: new_value})
 
 
 StreamToken.START = StreamToken(
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..1453faf0ef 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
         deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 19595df422..9d0d0be1f9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -15,12 +15,9 @@
 import logging
 
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
 
 from . import DEBUG_CACHES, register_cache
 
@@ -227,8 +224,20 @@ class _CacheDescriptorBase(object):
             )
 
         self.num_args = num_args
+
+        # list of the names of the args used as the cache key
         self.arg_names = all_args[1:num_args + 1]
 
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
         if "cache_context" in self.arg_names:
             raise Exception(
                 "cache_context arg cannot be included among the cache keys"
@@ -292,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -328,11 +350,9 @@ class CacheDescriptor(_CacheDescriptorBase):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return preserve_context_over_deferred(observer)
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
 
@@ -342,10 +362,14 @@ class CacheDescriptor(_CacheDescriptorBase):
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()
 
-                return preserve_context_over_deferred(ret.observe())
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
@@ -362,7 +386,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """
 
     def __init__(self, orig, cached_method_name, list_name, num_args=None,
@@ -433,8 +461,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
 
@@ -443,8 +470,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -471,7 +497,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
                     results.update(res)
                     return results
 
-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index ff67b1d794..990216145e 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
 def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
+
+    Deprecated: this almost certainly doesn't do want you want, ie make
+    the deferred follow the synapse logcontext rules: try
+    ``make_deferred_yieldable`` instead.
     """
     if context is None:
         context = LoggingContext.current_context()
@@ -330,12 +334,8 @@ def preserve_fn(f):
         LoggingContext.set_current_context(LoggingContext.sentinel)
         return result
 
-    # XXX: why is this here rather than inside g? surely we want to preserve
-    # the context from the time the function was called, not when it was
-    # wrapped?
-    current = LoggingContext.current_context()
-
     def g(*args, **kwargs):
+        current = LoggingContext.current_context()
         res = f(*args, **kwargs)
         if isinstance(res, defer.Deferred) and not res.called:
             # The function will have reset the context before returning, so
@@ -359,6 +359,25 @@ def preserve_fn(f):
     return g
 
 
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed (or is not actually a Deferred), essentially
+    does nothing (just returns another completed deferred with the
+    result/failure).
+
+    If the deferred has not yet completed, resets the logcontext before
+    returning a deferred. Then, when the deferred completes, restores the
+    current logcontext before running callbacks/errbacks.
+
+    (This is more-or-less the opposite operation to preserve_fn.)
+    """
+    with PreserveLoggingContext():
+        r = yield deferred
+    defer.returnValue(r)
+
+
 # modules to ignore in `logcontext_tracer`
 _to_ignore = [
     "synapse.util.logcontext",
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 31659156ae..c4dd9ae2c7 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         events ([synapse.events.EventBase]): list of events to filter
     """
     forgotten = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.who_forgot_in_room)(
+        defer.maybeDeferred(
+            preserve_fn(store.who_forgot_in_room),
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)