summary refs log tree commit diff
path: root/synapse/storage/data_stores
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-06 13:09:40 +0000
committerErik Johnston <erik@matrix.org>2019-12-06 13:43:40 +0000
commitd64bb32a73761ad55f53152756b8e0c10e1de9b0 (patch)
tree05d724d3d3eafd1842e6726dfde21ab2e8c365cb /synapse/storage/data_stores
parentChange DataStores to accept 'database' param. (diff)
downloadsynapse-d64bb32a73761ad55f53152756b8e0c10e1de9b0.tar.xz
Move are_all_users_on_domain checks to main data store.
Diffstat (limited to 'synapse/storage/data_stores')
-rw-r--r--synapse/storage/data_stores/main/__init__.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 7f5fd81bcf..66f8a9f3a7 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -115,7 +115,17 @@ class DataStore(
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
+
+        all_users_native = are_all_users_on_domain(
+            db_conn.cursor(), database.engine, hs.hostname
+        )
+        if not all_users_native:
+            raise Exception(
+                "Found users in database not native to %s!\n"
+                "You cannot changed a synapse server_name after it's been configured"
+                % (self.hostname,)
+            )
 
         self._stream_id_gen = StreamIdGenerator(
             db_conn,
@@ -555,3 +565,15 @@ class DataStore(
             retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
             desc="search_users",
         )
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
+    sql = database_engine.convert_param_style(
+        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+    )
+    pat = "%:" + domain
+    txn.execute(sql, (pat,))
+    num_not_matching = txn.fetchall()[0][0]
+    if num_not_matching == 0:
+        return True
+    return False