summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11313.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py9
3 files changed, 9 insertions, 5 deletions
diff --git a/changelog.d/11313.misc b/changelog.d/11313.misc
new file mode 100644
index 0000000000..86594a332d
--- /dev/null
+++ b/changelog.d/11313.misc
@@ -0,0 +1 @@
+Add type hints to storage classes.
diff --git a/mypy.ini b/mypy.ini
index 3b7e1eb708..f0af4ab289 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -53,7 +53,6 @@ exclude = (?x)
    |synapse/storage/databases/main/stats.py
    |synapse/storage/databases/main/transactions.py
    |synapse/storage/databases/main/user_directory.py
-   |synapse/storage/databases/main/user_erasure_store.py
    |synapse/storage/schema/
 
    |tests/api/test_auth.py
@@ -184,6 +183,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.user_erasure_store]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.util.*]
 disallow_untyped_defs = True
 
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 1ecdd40c38..f79006533f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -14,11 +14,12 @@
 
 from typing import Dict, Iterable
 
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.util.caches.descriptors import cached, cachedList
 
 
-class UserErasureWorkerStore(SQLBaseStore):
+class UserErasureWorkerStore(CacheInvalidationWorkerStore):
     @cached()
     async def is_user_erased(self, user_id: str) -> bool:
         """
@@ -69,7 +70,7 @@ class UserErasureStore(UserErasureWorkerStore):
             user_id: full user_id to be erased
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             # first check if they are already in the list
             txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
             if txn.fetchone():
@@ -89,7 +90,7 @@ class UserErasureStore(UserErasureWorkerStore):
             user_id: full user_id to be un-erased
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             # first check if they are already in the list
             txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
             if not txn.fetchone():