summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9975.misc1
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py13
-rw-r--r--synapse/util/caches/descriptors.py14
-rw-r--r--tests/util/caches/test_descriptors.py17
6 files changed, 31 insertions, 20 deletions
diff --git a/changelog.d/9975.misc b/changelog.d/9975.misc
new file mode 100644
index 0000000000..28b1e40c2b
--- /dev/null
+++ b/changelog.d/9975.misc
@@ -0,0 +1 @@
+Minor enhancements to the `@cachedList` descriptor.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c9346de316..a1f98b7e38 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
     )
-    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
         rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 398d6b6acb..9ba5778a88 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         num_args=1,
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
-        self, user_ids: List[str]
+        self, user_ids: Iterable[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
         txn: Connection,
-        user_ids: List[str],
+        user_ids: Iterable[str],
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index acf6b2fb64..1ecdd40c38 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Iterable
+
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
         return bool(result)
 
     @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
-    async def are_users_erased(self, user_ids):
+    async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
         """
         Checks which users in a list have requested erasure
 
         Args:
-            user_ids (iterable[str]): full user id to check
+            user_ids: full user ids to check
 
         Returns:
-            dict[str, bool]:
-                for each user, whether the user has requested erasure.
+            for each user, whether the user has requested erasure.
         """
-        # this serves the dual purpose of (a) making sure we can do len and
-        # iterate it multiple times, and (b) avoiding duplicates.
-        user_ids = tuple(set(user_ids))
-
         rows = await self.db_pool.simple_select_many_batch(
             table="erased_users",
             column="user_id",
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index ac4a078b26..3a4d027095 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -322,8 +322,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
-    Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped function.
+    Given an iterable of keys it looks in the cache to find any hits, then passes
+    the tuple of missing keys to the wrapped function.
 
     Once wrapped, the function returns a Deferred which resolves to the list
     of results.
@@ -437,7 +437,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     return f
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = list(missing)
+                # copy the missing set before sending it to the callee, to guard against
+                # modification.
+                args_to_call[self.list_name] = tuple(missing)
 
                 cached_defers.append(
                     defer.maybeDeferred(
@@ -522,14 +524,14 @@ def cachedList(
 
     Used to do batch lookups for an already created cache. A single argument
     is specified as a list that is iterated through to lookup keys in the
-    original cache. A new list consisting of the keys that weren't in the cache
-    get passed to the original function, the result of which is stored in the
+    original cache. A new tuple consisting of the (deduplicated) keys that weren't in
+    the cache gets passed to the original function, the result of which is stored in the
     cache.
 
     Args:
         cached_method_name: The name of the single-item lookup method.
             This is only used to find the cache to use.
-        list_name: The name of the argument that is the list to use to
+        list_name: The name of the argument that is the iterable to use to
             do batch lookups in the cache.
         num_args: Number of arguments to use as the key in the cache
             (including list_name). Defaults to all named parameters.
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 178ac8a68c..bbbc276697 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -666,18 +666,20 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         with LoggingContext("c1") as c1:
             obj = Cls()
             obj.mock.return_value = {10: "fish", 20: "chips"}
+
+            # start the lookup off
             d1 = obj.list_fn([10, 20], 2)
             self.assertEqual(current_context(), SENTINEL_CONTEXT)
             r = yield d1
             self.assertEqual(current_context(), c1)
-            obj.mock.assert_called_once_with([10, 20], 2)
+            obj.mock.assert_called_once_with((10, 20), 2)
             self.assertEqual(r, {10: "fish", 20: "chips"})
             obj.mock.reset_mock()
 
             # a call with different params should call the mock again
             obj.mock.return_value = {30: "peas"}
             r = yield obj.list_fn([20, 30], 2)
-            obj.mock.assert_called_once_with([30], 2)
+            obj.mock.assert_called_once_with((30,), 2)
             self.assertEqual(r, {20: "chips", 30: "peas"})
             obj.mock.reset_mock()
 
@@ -692,6 +694,15 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj.mock.assert_not_called()
             self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
 
+            # we should also be able to use a (single-use) iterable, and should
+            # deduplicate the keys
+            obj.mock.reset_mock()
+            obj.mock.return_value = {40: "gravy"}
+            iterable = (x for x in [10, 40, 40])
+            r = yield obj.list_fn(iterable, 2)
+            obj.mock.assert_called_once_with((40,), 2)
+            self.assertEqual(r, {10: "fish", 40: "gravy"})
+
     @defer.inlineCallbacks
     def test_invalidate(self):
         """Make sure that invalidation callbacks are called."""
@@ -717,7 +728,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         # cache miss
         obj.mock.return_value = {10: "fish", 20: "chips"}
         r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
-        obj.mock.assert_called_once_with([10, 20], 2)
+        obj.mock.assert_called_once_with((10, 20), 2)
         self.assertEqual(r1, {10: "fish", 20: "chips"})
         obj.mock.reset_mock()