summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-05-14 11:12:36 +0100
committerGitHub <noreply@github.com>2021-05-14 11:12:36 +0100
commit5090f26b636bf4439575767a2272d033fb33b2d5 (patch)
tree11ed4623449a07d80598ceadcc923ddfa6dd63a2 /synapse/storage
parentRemove unnecessary SystemRandom from SQLBaseStore (#9987) (diff)
downloadsynapse-5090f26b636bf4439575767a2272d033fb33b2d5.tar.xz
Minor `@cachedList` enhancements (#9975)
- use a tuple rather than a list for the iterable that is passed into the
  wrapped function, for performance

- test that we can pass an iterable and that keys are correctly deduped.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py13
3 files changed, 8 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c9346de316..a1f98b7e38 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
     )
-    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
         rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 398d6b6acb..9ba5778a88 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         num_args=1,
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
-        self, user_ids: List[str]
+        self, user_ids: Iterable[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
@@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
         txn: Connection,
-        user_ids: List[str],
+        user_ids: Iterable[str],
     ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index acf6b2fb64..1ecdd40c38 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Iterable
+
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
         return bool(result)
 
     @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
-    async def are_users_erased(self, user_ids):
+    async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
         """
         Checks which users in a list have requested erasure
 
         Args:
-            user_ids (iterable[str]): full user id to check
+            user_ids: full user ids to check
 
         Returns:
-            dict[str, bool]:
-                for each user, whether the user has requested erasure.
+            for each user, whether the user has requested erasure.
         """
-        # this serves the dual purpose of (a) making sure we can do len and
-        # iterate it multiple times, and (b) avoiding duplicates.
-        user_ids = tuple(set(user_ids))
-
         rows = await self.db_pool.simple_select_many_batch(
             table="erased_users",
             column="user_id",