summary refs log tree commit diff
path: root/synapse/storage/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/__init__.py')
-rw-r--r--synapse/storage/__init__.py146
1 files changed, 112 insertions, 34 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f4dec70393..61215bbc7b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 15
+SCHEMA_VERSION = 16
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -104,14 +104,16 @@ class DataStore(RoomMemberStore, RoomStore,
 
         self.client_ip_last_seen.prefill(*key + (now,))
 
-        yield self._simple_insert(
+        yield self._simple_upsert(
             "user_ips",
-            {
-                "user": user.to_string(),
+            keyvalues={
+                "user_id": user.to_string(),
                 "access_token": access_token,
-                "device_id": device_id,
                 "ip": ip,
                 "user_agent": user_agent,
+            },
+            values={
+                "device_id": device_id,
                 "last_seen": now,
             },
             desc="insert_client_ip",
@@ -120,7 +122,7 @@ class DataStore(RoomMemberStore, RoomStore,
     def get_user_ip_and_agents(self, user):
         return self._simple_select_list(
             table="user_ips",
-            keyvalues={"user": user.to_string()},
+            keyvalues={"user_id": user.to_string()},
             retcols=[
                 "device_id", "access_token", "ip", "user_agent", "last_seen"
             ],
@@ -148,21 +150,23 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass
 
 
-def prepare_database(db_conn):
+def prepare_database(db_conn, database_engine):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
     """
     try:
         cur = db_conn.cursor()
-        version_info = _get_or_create_schema_state(cur)
+        version_info = _get_or_create_schema_state(cur, database_engine)
 
         if version_info:
             user_version, delta_files, upgraded = version_info
-            _upgrade_existing_database(cur, user_version, delta_files, upgraded)
+            _upgrade_existing_database(
+                cur, user_version, delta_files, upgraded, database_engine
+            )
         else:
-            _setup_new_database(cur)
+            _setup_new_database(cur, database_engine)
 
-        cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+        # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
 
         cur.close()
         db_conn.commit()
@@ -171,7 +175,7 @@ def prepare_database(db_conn):
         raise
 
 
-def _setup_new_database(cur):
+def _setup_new_database(cur, database_engine):
     """Sets up the database by finding a base set of "full schemas" and then
     applying any necessary deltas.
 
@@ -225,31 +229,30 @@ def _setup_new_database(cur):
 
     directory_entries = os.listdir(sql_dir)
 
-    sql_script = "BEGIN TRANSACTION;\n"
     for filename in fnmatch.filter(directory_entries, "*.sql"):
         sql_loc = os.path.join(sql_dir, filename)
         logger.debug("Applying schema %s", sql_loc)
-        sql_script += read_schema(sql_loc)
-        sql_script += "\n"
-    sql_script += "COMMIT TRANSACTION;"
-    cur.executescript(sql_script)
+        executescript(cur, sql_loc)
 
     cur.execute(
-        "INSERT OR REPLACE INTO schema_version (version, upgraded)"
-        " VALUES (?,?)",
-        (max_current_ver, False)
+        database_engine.convert_param_style(
+            "INSERT INTO schema_version (version, upgraded)"
+            " VALUES (?,?)"
+        ),
+        (max_current_ver, False,)
     )
 
     _upgrade_existing_database(
         cur,
         current_version=max_current_ver,
         applied_delta_files=[],
-        upgraded=False
+        upgraded=False,
+        database_engine=database_engine,
     )
 
 
 def _upgrade_existing_database(cur, current_version, applied_delta_files,
-                               upgraded):
+                               upgraded, database_engine):
     """Upgrades an existing database.
 
     Delta files can either be SQL stored in *.sql files, or python modules
@@ -305,6 +308,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
     if not upgraded:
         start_ver += 1
 
+    logger.debug("applied_delta_files: %s", applied_delta_files)
+
     for v in range(start_ver, SCHEMA_VERSION + 1):
         logger.debug("Upgrading schema to v%d", v)
 
@@ -321,6 +326,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
         directory_entries.sort()
         for file_name in directory_entries:
             relative_path = os.path.join(str(v), file_name)
+            logger.debug("Found file: %s", relative_path)
             if relative_path in applied_delta_files:
                 continue
 
@@ -342,9 +348,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
                 module.run_upgrade(cur)
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
-                delta_schema = read_schema(absolute_path)
                 logger.debug("Applying schema %s", relative_path)
-                cur.executescript(delta_schema)
+                executescript(cur, absolute_path)
             else:
                 # Not a valid delta file.
                 logger.warn(
@@ -356,24 +361,82 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
 
             # Mark as done.
             cur.execute(
-                "INSERT INTO applied_schema_deltas (version, file)"
-                " VALUES (?,?)",
+                database_engine.convert_param_style(
+                    "INSERT INTO applied_schema_deltas (version, file)"
+                    " VALUES (?,?)",
+                ),
                 (v, relative_path)
             )
 
             cur.execute(
-                "INSERT OR REPLACE INTO schema_version (version, upgraded)"
-                " VALUES (?,?)",
+                database_engine.convert_param_style(
+                    "REPLACE INTO schema_version (version, upgraded)"
+                    " VALUES (?,?)",
+                ),
                 (v, True)
             )
 
 
-def _get_or_create_schema_state(txn):
+def get_statements(f):
+    statement_buffer = ""
+    in_comment = False  # If we're in a /* ... */ style comment
+
+    for line in f:
+        line = line.strip()
+
+        if in_comment:
+            # Check if this line contains an end to the comment
+            comments = line.split("*/", 1)
+            if len(comments) == 1:
+                continue
+            line = comments[1]
+            in_comment = False
+
+        # Remove inline block comments
+        line = re.sub(r"/\*.*\*/", " ", line)
+
+        # Does this line start a comment?
+        comments = line.split("/*", 1)
+        if len(comments) > 1:
+            line = comments[0]
+            in_comment = True
+
+        # Deal with line comments
+        line = line.split("--", 1)[0]
+        line = line.split("//", 1)[0]
+
+        # Find *all* semicolons. We need to treat first and last entry
+        # specially.
+        statements = line.split(";")
+
+        # We must prepend statement_buffer to the first statement
+        first_statement = "%s %s" % (
+            statement_buffer.strip(),
+            statements[0].strip()
+        )
+        statements[0] = first_statement
+
+        # Every entry, except the last, is a full statement
+        for statement in statements[:-1]:
+            yield statement.strip()
+
+        # The last entry did *not* end in a semicolon, so we store it for the
+        # next semicolon we find
+        statement_buffer = statements[-1].strip()
+
+
+def executescript(txn, schema_path):
+    with open(schema_path, 'r') as f:
+        for statement in get_statements(f):
+            txn.execute(statement)
+
+
+def _get_or_create_schema_state(txn, database_engine):
+    # Bluntly try creating the schema_version tables.
     schema_path = os.path.join(
         dir_path, "schema", "schema_version.sql",
     )
-    create_schema = read_schema(schema_path)
-    txn.executescript(create_schema)
+    executescript(txn, schema_path)
 
     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
@@ -382,10 +445,13 @@ def _get_or_create_schema_state(txn):
 
     if current_version:
         txn.execute(
-            "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+            database_engine.convert_param_style(
+                "SELECT file FROM applied_schema_deltas WHERE version >= ?"
+            ),
             (current_version,)
         )
-        return current_version, txn.fetchall(), upgraded
+        applied_deltas = [d for d, in txn.fetchall()]
+        return current_version, applied_deltas, upgraded
 
     return None
 
@@ -417,7 +483,19 @@ def prepare_sqlite3_database(db_conn):
 
             if row and row[0]:
                 db_conn.execute(
-                    "INSERT OR REPLACE INTO schema_version (version, upgraded)"
+                    "REPLACE INTO schema_version (version, upgraded)"
                     " VALUES (?,?)",
                     (row[0], False)
                 )
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
+    sql = database_engine.convert_param_style(
+        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+    )
+    pat = "%:" + domain
+    txn.execute(sql, (pat,))
+    num_not_matching = txn.fetchall()[0][0]
+    if num_not_matching == 0:
+        return True
+    return False