summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-10-21 10:41:12 +0100
committerErik Johnston <erik@matrix.org>2020-10-21 10:41:12 +0100
commitf0bcf6f578ce22207a9ba5ed505b8ecfebc35638 (patch)
treef46a29865c23966f930fa61f5ccacbd508dcbb20
parentFormat SQL (diff)
downloadsynapse-f0bcf6f578ce22207a9ba5ed505b8ecfebc35638.tar.xz
Add registration store to mypy
-rw-r--r--mypy.ini1
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/registration.py92
3 files changed, 49 insertions, 45 deletions
diff --git a/mypy.ini b/mypy.ini
index b5db54ee3b..516361f2f5 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -55,6 +55,7 @@ files =
   synapse/spam_checker_api,
   synapse/state,
   synapse/storage/databases/main/events.py,
+  synapse/storage/databases/main/registration.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9b16f45f3e..43660ec4fb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,7 +146,6 @@ class DataStore(
             db_conn, "e2e_cross_signing_keys", "stream_id"
         )
 
-        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 80d9606783..d856b76952 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -21,9 +21,11 @@ from typing import Any, Dict, List, Optional, Tuple
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached
@@ -33,7 +35,7 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 logger = logging.getLogger(__name__)
 
 
-class RegistrationWorkerStore(SQLBaseStore):
+class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
@@ -1020,13 +1022,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
         return 1
 
+    async def set_user_deactivated_status(
+        self, user_id: str, deactivated: bool
+    ) -> None:
+        """Set the `deactivated` property for the provided user to the provided value.
+
+        Args:
+            user_id: The ID of the user to set the status for.
+            deactivated: The value to set for `deactivated`.
+        """
+
+        await self.db_pool.runInteraction(
+            "set_user_deactivated_status",
+            self.set_user_deactivated_status_txn,
+            user_id,
+            deactivated,
+        )
+
+    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+        self.db_pool.simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"deactivated": 1 if deactivated else 0},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_deactivated_status, (user_id,)
+        )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
 
-class RegistrationStore(RegistrationBackgroundUpdateStore):
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        return res if res else False
+
+
+class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -1378,18 +1423,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         await self.db_pool.runInteraction("delete_access_token", f)
 
-    @cached()
-    async def is_guest(self, user_id: str) -> bool:
-        res = await self.db_pool.simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user_id},
-            retcol="is_guest",
-            allow_none=True,
-            desc="is_guest",
-        )
-
-        return res if res else False
-
     async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1551,35 +1584,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def set_user_deactivated_status(
-        self, user_id: str, deactivated: bool
-    ) -> None:
-        """Set the `deactivated` property for the provided user to the provided value.
-
-        Args:
-            user_id: The ID of the user to set the status for.
-            deactivated: The value to set for `deactivated`.
-        """
-
-        await self.db_pool.runInteraction(
-            "set_user_deactivated_status",
-            self.set_user_deactivated_status_txn,
-            user_id,
-            deactivated,
-        )
-
-    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self.db_pool.simple_update_one_txn(
-            txn=txn,
-            table="users",
-            keyvalues={"name": user_id},
-            updatevalues={"deactivated": 1 if deactivated else 0},
-        )
-        self._invalidate_cache_and_stream(
-            txn, self.get_user_deactivated_status, (user_id,)
-        )
-        txn.call_after(self.is_guest.invalidate, (user_id,))
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """