summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-02-24 10:13:53 +0000
committerGitHub <noreply@github.com>2021-02-24 10:13:53 +0000
commit0b5c967813a8e5f3338b6ecd5f3742e91b8a100b (patch)
tree8001d34fb268c378624382d13a94c6e2c0b047e4 /synapse/storage/databases/main/registration.py
parentAdd a comment about systemd-python. (#9464) (diff)
downloadsynapse-0b5c967813a8e5f3338b6ecd5f3742e91b8a100b.tar.xz
Refactor to ensure we call check_consistency (#9470)
The idea here is to stop people forgetting to call `check_consistency`. Folks can still just pass in `None` to the new args in `build_sequence_generator`, but hopefully they won't.
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py19
1 files changed, 16 insertions, 3 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d5b5507815..61a7556e56 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
 
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
+            db_conn,
             database.engine,
             find_max_generated_user_id_localpart,
             "user_id_seq",
+            table=None,
+            id_column=None,
         )
 
         self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._clock = hs.get_clock()