| diff --git a/changelog.d/3384.misc b/changelog.d/3384.misc
new file mode 100644
index 0000000000..5d56c876d9
--- /dev/null
+++ b/changelog.d/3384.misc
@@ -0,0 +1 @@
+Rewrite cache list decorator
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
 index f8a07df6b8..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            # If we're passed a cache_context then we'll want to call its invalidate()
-            # whenever we are invalidated
+            # If we're passed a cache_context then we'll want to call its
+            # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
             results = {}
-            cached_defers = {}
-            missing = []
+
+            def update_results_dict(res, arg):
+                results[arg] = res
+
+            # list of deferreds to wait for
+            cached_defers = []
+
+            missing = set()
 
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
-                def cache_get(arg):
-                    return cache.get(arg, callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    return arg
             else:
-                key = list(keyargs)
+                keylist = list(keyargs)
 
-                def cache_get(arg):
-                    key[self.list_pos] = arg
-                    return cache.get(tuple(key), callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    keylist[self.list_pos] = arg
+                    return tuple(keylist)
 
             for arg in list_args:
                 try:
-                    res = cache_get(arg)
-
+                    res = cache.get(arg_to_cache_key(arg),
+                                    callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
                         res = res.observe()
-                        res.addCallback(lambda r, arg: (arg, r), arg)
-                        cached_defers[arg] = res
+                        res.addCallback(update_results_dict, arg)
+                        cached_defers.append(res)
                     else:
                         results[arg] = res.get_result()
                 except KeyError:
-                    missing.append(arg)
+                    missing.add(arg)
 
             if missing:
+                # we need an observable deferred for each entry in the list,
+                # which we put in the cache. Each deferred resolves with the
+                # relevant result for that key.
+                deferreds_map = {}
+                for arg in missing:
+                    deferred = defer.Deferred()
+                    deferreds_map[arg] = deferred
+                    key = arg_to_cache_key(arg)
+                    observable = ObservableDeferred(deferred)
+                    cache.set(key, observable, callback=invalidate_callback)
+
+                def complete_all(res):
+                    # the wrapped function has completed. It returns a
+                    # a dict. We can now resolve the observable deferreds in
+                    # the cache and update our own result map.
+                    for e in missing:
+                        val = res.get(e, None)
+                        deferreds_map[e].callback(val)
+                        results[e] = val
+
+                def errback(f):
+                    # the wrapped function has failed. Invalidate any cache
+                    # entries we're supposed to be populating, and fail
+                    # their deferreds.
+                    for e in missing:
+                        key = arg_to_cache_key(e)
+                        cache.invalidate(key)
+                        deferreds_map[e].errback(f)
+
+                    # return the failure, to propagate to our caller.
+                    return f
+
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = list(missing)
 
-                ret_d = defer.maybeDeferred(
+                cached_defers.append(defer.maybeDeferred(
                     logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    if num_args == 1:
-                        cache.set(
-                            arg, observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, arg)
-                    else:
-                        key = list(keyargs)
-                        key[self.list_pos] = arg
-                        cache.set(
-                            tuple(key), observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached_defers[arg] = res
+                ).addCallbacks(complete_all, errback))
 
             if cached_defers:
-                def update_results_dict(res):
-                    results.update(res)
-                    return results
-
-                return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    list(cached_defers.values()),
+                d = defer.gatherResults(
+                    cached_defers,
                     consumeErrors=True,
-                ).addCallback(update_results_dict).addErrback(
+                ).addCallbacks(
+                    lambda _: results,
                     unwrapFirstError
-                ))
+                )
+                return logcontext.make_deferred_yieldable(d)
             else:
                 return results
 
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
     cache.
 
     Args:
-        cache (Cache): The underlying cache to use.
+        cached_method_name (str): The name of the single-item lookup method.
+            This is only used to find the cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
 index 8176a7dabd..ca8a7c907f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -273,3 +273,104 @@ class DescriptorTestCase(unittest.TestCase):
         r = yield obj.fn(2, 3)
         self.assertEqual(r, 'chips')
         obj.mock.assert_not_called()
+
+
+class CachedListDescriptorTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def test_cache(self):
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                assert (
+                    logcontext.LoggingContext.current_context().request == "c1"
+                )
+                defer.returnValue(self.mock(args1, arg2))
+
+        with logcontext.LoggingContext() as c1:
+            c1.request = "c1"
+            obj = Cls()
+            obj.mock.return_value = {10: 'fish', 20: 'chips'}
+            d1 = obj.list_fn([10, 20], 2)
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                logcontext.LoggingContext.sentinel,
+            )
+            r = yield d1
+            self.assertEqual(
+                logcontext.LoggingContext.current_context(),
+                c1
+            )
+            obj.mock.assert_called_once_with([10, 20], 2)
+            self.assertEqual(r, {10: 'fish', 20: 'chips'})
+            obj.mock.reset_mock()
+
+            # a call with different params should call the mock again
+            obj.mock.return_value = {30: 'peas'}
+            r = yield obj.list_fn([20, 30], 2)
+            obj.mock.assert_called_once_with([30], 2)
+            self.assertEqual(r, {20: 'chips', 30: 'peas'})
+            obj.mock.reset_mock()
+
+            # all the values should now be cached
+            r = yield obj.fn(10, 2)
+            self.assertEqual(r, 'fish')
+            r = yield obj.fn(20, 2)
+            self.assertEqual(r, 'chips')
+            r = yield obj.fn(30, 2)
+            self.assertEqual(r, 'peas')
+            r = yield obj.list_fn([10, 20, 30], 2)
+            obj.mock.assert_not_called()
+            self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})
+
+    @defer.inlineCallbacks
+    def test_invalidate(self):
+        """Make sure that invalidation callbacks are called."""
+        class Cls(object):
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, arg2):
+                pass
+
+            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
+            def list_fn(self, args1, arg2):
+                # we want this to behave like an asynchronous function
+                yield run_on_reactor()
+                defer.returnValue(self.mock(args1, arg2))
+
+        obj = Cls()
+        invalidate0 = mock.Mock()
+        invalidate1 = mock.Mock()
+
+        # cache miss
+        obj.mock.return_value = {10: 'fish', 20: 'chips'}
+        r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
+        obj.mock.assert_called_once_with([10, 20], 2)
+        self.assertEqual(r1, {10: 'fish', 20: 'chips'})
+        obj.mock.reset_mock()
+
+        # cache hit
+        r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
+        obj.mock.assert_not_called()
+        self.assertEqual(r2, {10: 'fish', 20: 'chips'})
+
+        invalidate0.assert_not_called()
+        invalidate1.assert_not_called()
+
+        # now if we invalidate the keys, both invalidations should get called
+        obj.fn.invalidate((10, 2))
+        invalidate0.assert_called_once()
+        invalidate1.assert_called_once()
 |