diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c9346de316..a1f98b7e38 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
- async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+ async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 398d6b6acb..9ba5778a88 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
num_args=1,
)
async def _get_bare_e2e_cross_signing_keys_bulk(
- self, user_ids: List[str]
+ self, user_ids: Iterable[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
- user_ids: List[str],
+ user_ids: Iterable[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index acf6b2fb64..1ecdd40c38 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, Iterable
+
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
- async def are_users_erased(self, user_ids):
+ async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
"""
Checks which users in a list have requested erasure
Args:
- user_ids (iterable[str]): full user id to check
+ user_ids: full user ids to check
Returns:
- dict[str, bool]:
- for each user, whether the user has requested erasure.
+ for each user, whether the user has requested erasure.
"""
- # this serves the dual purpose of (a) making sure we can do len and
- # iterate it multiple times, and (b) avoiding duplicates.
- user_ids = tuple(set(user_ids))
-
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index ac4a078b26..3a4d027095 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -322,8 +322,8 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
- Given a list of keys it looks in the cache to find any hits, then passes
- the list of missing keys to the wrapped function.
+ Given an iterable of keys it looks in the cache to find any hits, then passes
+ the tuple of missing keys to the wrapped function.
Once wrapped, the function returns a Deferred which resolves to the list
of results.
@@ -437,7 +437,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
return f
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = list(missing)
+ # copy the missing set before sending it to the callee, to guard against
+ # modification.
+ args_to_call[self.list_name] = tuple(missing)
cached_defers.append(
defer.maybeDeferred(
@@ -522,14 +524,14 @@ def cachedList(
Used to do batch lookups for an already created cache. A single argument
is specified as a list that is iterated through to lookup keys in the
- original cache. A new list consisting of the keys that weren't in the cache
- get passed to the original function, the result of which is stored in the
+ original cache. A new tuple consisting of the (deduplicated) keys that weren't in
+ the cache gets passed to the original function, the result of which is stored in the
cache.
Args:
cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use.
- list_name: The name of the argument that is the list to use to
+ list_name: The name of the argument that is the iterable to use to
do batch lookups in the cache.
num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
|