diff options
Diffstat (limited to 'synapse/storage/__init__.py')
-rw-r--r-- | synapse/storage/__init__.py | 60 |
1 files changed, 58 insertions, 2 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 1cede2809d..66658f6721 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -47,6 +47,23 @@ import os logger = logging.getLogger(__name__) + +SCHEMAS = [ + "transactions", + "pdu", + "users", + "profiles", + "presence", + "im", + "room_aliases", +] + + +# Remember to update this number every time an incompatible change is made to +# database schema files, so the users will be informed on server restarts. +SCHEMA_VERSION = 3 + + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying something went wrong. @@ -78,7 +95,7 @@ class DataStore(RoomMemberStore, RoomStore, stream_ordering = self.min_token try: - yield self._db_pool.runInteraction( + yield self.runInteraction( self._persist_pdu_event_txn, pdu=pdu, event=event, @@ -291,7 +308,7 @@ class DataStore(RoomMemberStore, RoomStore, prev_state_pdu=prev_state_pdu, ) - return self._db_pool.runInteraction(_snapshot) + return self.runInteraction(_snapshot) class Snapshot(object): @@ -361,3 +378,42 @@ def read_schema(schema): """ with open(schema_path(schema)) as schema_file: return schema_file.read() + + +def prepare_database(db_conn): + """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we + don't have to worry about overwriting existing content. + """ + c = db_conn.cursor() + c.execute("PRAGMA user_version") + row = c.fetchone() + + if row and row[0]: + user_version = row[0] + + if user_version > SCHEMA_VERSION: + raise ValueError("Cannot use this database as it is too " + + "new for the server to understand" + ) + elif user_version < SCHEMA_VERSION: + logging.info("Upgrading database from version %d", + user_version + ) + + # Run every version since after the current version. + for v in range(user_version + 1, SCHEMA_VERSION + 1): + sql_script = read_schema("delta/v%d" % (v)) + c.executescript(sql_script) + + db_conn.commit() + + else: + for sql_loc in SCHEMAS: + sql_script = read_schema(sql_loc) + + c.executescript(sql_script) + db_conn.commit() + c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION) + + c.close() + |