summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py118
1 files changed, 62 insertions, 56 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 1a859352b6..6c5b29288a 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -787,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
 
-class RegistrationStore(
+class RegistrationBackgroundUpdateStore(
     RegistrationWorkerStore, background_updates.BackgroundUpdateStore
 ):
     def __init__(self, db_conn, hs):
-        super(RegistrationStore, self).__init__(db_conn, hs)
+        super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
 
         self.clock = hs.get_clock()
+        self.config = hs.config
 
         self.register_background_index_update(
             "access_tokens_device_index",
@@ -809,8 +810,6 @@ class RegistrationStore(
             columns=["creation_ts"],
         )
 
-        self._account_validity = hs.config.account_validity
-
         # we no longer use refresh tokens, but it's possible that some people
         # might have a background update queued to build this index. Just
         # clear the background update.
@@ -824,17 +823,6 @@ class RegistrationStore(
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
-        # Create a background job for culling expired 3PID validity tokens
-        def start_cull():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "cull_expired_threepid_validation_tokens",
-                self.cull_expired_threepid_validation_tokens,
-            )
-
-        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
     @defer.inlineCallbacks
     def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@@ -897,6 +885,54 @@ class RegistrationStore(
         return nb_processed
 
     @defer.inlineCallbacks
+    def _bg_user_threepids_grandfather(self, progress, batch_size):
+        """We now track which identity servers a user binds their 3PID to, so
+        we need to handle the case of existing bindings where we didn't track
+        this.
+
+        We do this by grandfathering in existing user threepids assuming that
+        they used one of the server configured trusted identity servers.
+        """
+        id_servers = set(self.config.trusted_third_party_id_servers)
+
+        def _bg_user_threepids_grandfather_txn(txn):
+            sql = """
+                INSERT INTO user_threepid_id_server
+                    (user_id, medium, address, id_server)
+                SELECT user_id, medium, address, ?
+                FROM user_threepids
+            """
+
+            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+        if id_servers:
+            yield self.runInteraction(
+                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
+            )
+
+        yield self._end_background_update("user_threepids_grandfather")
+
+        return 1
+
+
+class RegistrationStore(RegistrationBackgroundUpdateStore):
+    def __init__(self, db_conn, hs):
+        super(RegistrationStore, self).__init__(db_conn, hs)
+
+        self._account_validity = hs.config.account_validity
+
+        # Create a background job for culling expired 3PID validity tokens
+        def start_cull():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "cull_expired_threepid_validation_tokens",
+                self.cull_expired_threepid_validation_tokens,
+            )
+
+        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
+
+    @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
         """Adds an access token for the given user.
 
@@ -1244,36 +1280,6 @@ class RegistrationStore(
             desc="get_users_pending_deactivation",
         )
 
-    @defer.inlineCallbacks
-    def _bg_user_threepids_grandfather(self, progress, batch_size):
-        """We now track which identity servers a user binds their 3PID to, so
-        we need to handle the case of existing bindings where we didn't track
-        this.
-
-        We do this by grandfathering in existing user threepids assuming that
-        they used one of the server configured trusted identity servers.
-        """
-        id_servers = set(self.config.trusted_third_party_id_servers)
-
-        def _bg_user_threepids_grandfather_txn(txn):
-            sql = """
-                INSERT INTO user_threepid_id_server
-                    (user_id, medium, address, id_server)
-                SELECT user_id, medium, address, ?
-                FROM user_threepids
-            """
-
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
-
-        if id_servers:
-            yield self.runInteraction(
-                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
-            )
-
-        yield self._end_background_update("user_threepids_grandfather")
-
-        return 1
-
     def validate_threepid_session(self, session_id, client_secret, token, current_ts):
         """Attempt to validate a threepid session using a token
 
@@ -1465,17 +1471,6 @@ class RegistrationStore(
             self.clock.time_msec(),
         )
 
-    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self._simple_update_one_txn(
-            txn=txn,
-            table="users",
-            keyvalues={"name": user_id},
-            updatevalues={"deactivated": 1 if deactivated else 0},
-        )
-        self._invalidate_cache_and_stream(
-            txn, self.get_user_deactivated_status, (user_id,)
-        )
-
     @defer.inlineCallbacks
     def set_user_deactivated_status(self, user_id, deactivated):
         """Set the `deactivated` property for the provided user to the provided value.
@@ -1491,3 +1486,14 @@ class RegistrationStore(
             user_id,
             deactivated,
         )
+
+    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+        self._simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"deactivated": 1 if deactivated else 0},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_deactivated_status, (user_id,)
+        )