diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 12 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/__init__.py | 24 |
2 files changed, 23 insertions, 13 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 8fb18203dc..ec89f645d4 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -49,15 +49,3 @@ class Storage(object): self.persistence = EventsPersistenceStorage(hs, stores) self.purge_events = PurgeEventsStorage(hs, stores) self.state = StateGroupStorage(hs, stores) - - -def are_all_users_on_domain(txn, database_engine, domain): - sql = database_engine.convert_param_style( - "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" - ) - pat = "%:" + domain - txn.execute(sql, (pat,)) - num_not_matching = txn.fetchall()[0][0] - if num_not_matching == 0: - return True - return False diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 7f5fd81bcf..66f8a9f3a7 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -115,7 +115,17 @@ class DataStore( def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine + + all_users_native = are_all_users_on_domain( + db_conn.cursor(), database.engine, hs.hostname + ) + if not all_users_native: + raise Exception( + "Found users in database not native to %s!\n" + "You cannot changed a synapse server_name after it's been configured" + % (self.hostname,) + ) self._stream_id_gen = StreamIdGenerator( db_conn, @@ -555,3 +565,15 @@ class DataStore( retcols=["name", "password_hash", "is_guest", "admin", "user_type"], desc="search_users", ) + + +def are_all_users_on_domain(txn, database_engine, domain): + sql = database_engine.convert_param_style( + "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" + ) + pat = "%:" + domain + txn.execute(sql, (pat,)) + num_not_matching = txn.fetchall()[0][0] + if num_not_matching == 0: + return True + return False |