summary refs log tree commit diff
path: root/synapse/storage/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/__init__.py')
-rw-r--r--synapse/storage/__init__.py164
1 files changed, 116 insertions, 48 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4b16f445d6..30cba47717 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -51,6 +51,8 @@ import logging
 import os
 import re
 
+import threading
+
 
 logger = logging.getLogger(__name__)
 
@@ -89,6 +91,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self.min_token_deferred = self._get_min_token()
         self.min_token = None
 
+        self._next_stream_id_lock = threading.Lock()
+        self._next_stream_id = int(hs.get_clock().time_msec()) * 1000
+
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context, backfilled=False,
@@ -172,7 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
                         "type": s.type,
                         "state_key": s.state_key,
                     },
-                    or_replace=True,
                 )
 
         if event.is_state() and is_new_state:
@@ -186,7 +190,6 @@ class DataStore(RoomMemberStore, RoomStore,
                         "type": event.type,
                         "state_key": event.state_key,
                     },
-                    or_replace=True,
                 )
 
                 for prev_state_id, _ in event.prev_state:
@@ -285,7 +288,6 @@ class DataStore(RoomMemberStore, RoomStore,
                 "internal_metadata": metadata_json.decode("UTF-8"),
                 "json": encode_canonical_json(event_dict).decode("UTF-8"),
             },
-            or_replace=True,
         )
 
         content = encode_canonical_json(
@@ -303,8 +305,9 @@ class DataStore(RoomMemberStore, RoomStore,
             "depth": event.depth,
         }
 
-        if stream_ordering is not None:
-            vals["stream_ordering"] = stream_ordering
+        if stream_ordering is None:
+            stream_ordering = self.get_next_stream_id()
+
 
         unrec = {
             k: v
@@ -322,21 +325,18 @@ class DataStore(RoomMemberStore, RoomStore,
             unrec
         ).decode("UTF-8")
 
-        try:
-            self._simple_insert_txn(
-                txn,
-                "events",
-                vals,
-                or_replace=(not outlier),
-                or_ignore=bool(outlier),
-            )
-        except:
-            logger.warn(
-                "Failed to persist, probably duplicate: %s",
-                event.event_id,
-                exc_info=True,
-            )
-            raise _RollbackButIsFineException("_persist_event")
+        sql = (
+            "INSERT INTO events"
+            " (stream_ordering, topological_ordering, event_id, type,"
+            " room_id, content, processed, outlier, depth)"
+            " VALUES (%s,?,?,?,?,?,?,?,?)"
+        ) % (stream_ordering,)
+
+        txn.execute(
+            sql,
+            (event.depth, event.event_id, event.type, event.room_id,
+             content, True, outlier, event.depth)
+        )
 
         if context.rejected:
             self._store_rejections_txn(txn, event.event_id, context.rejected)
@@ -357,7 +357,6 @@ class DataStore(RoomMemberStore, RoomStore,
                 txn,
                 "state_events",
                 vals,
-                or_replace=True,
             )
 
             if is_new_state and not context.rejected:
@@ -370,7 +369,6 @@ class DataStore(RoomMemberStore, RoomStore,
                         "type": event.type,
                         "state_key": event.state_key,
                     },
-                    or_replace=True,
                 )
 
             for e_id, h in event.prev_state:
@@ -383,7 +381,6 @@ class DataStore(RoomMemberStore, RoomStore,
                         "room_id": event.room_id,
                         "is_state": 1,
                     },
-                    or_ignore=True,
                 )
 
         for hash_alg, hash_base64 in event.hashes.items():
@@ -408,7 +405,6 @@ class DataStore(RoomMemberStore, RoomStore,
                     "room_id": event.room_id,
                     "auth_id": auth_id,
                 },
-                or_ignore=True,
             )
 
         (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
@@ -420,8 +416,7 @@ class DataStore(RoomMemberStore, RoomStore,
         # invalidate the cache for the redacted event
         self._get_event_cache.pop(event.redacts)
         txn.execute(
-            "INSERT OR IGNORE INTO redactions "
-            "(event_id, redacts) VALUES (?,?)",
+            "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
             (event.event_id, event.redacts)
         )
 
@@ -515,7 +510,8 @@ class DataStore(RoomMemberStore, RoomStore,
                 "ip": ip,
                 "user_agent": user_agent,
                 "last_seen": int(self._clock.time_msec()),
-            }
+            },
+            or_replace=True,
         )
 
     def get_user_ip_and_agents(self, user):
@@ -559,6 +555,12 @@ class DataStore(RoomMemberStore, RoomStore,
             "have_events", f,
         )
 
+    def get_next_stream_id(self):
+        with self._next_stream_id_lock:
+            i = self._next_stream_id
+            self._next_stream_id += 1
+            return i
+
 
 def read_schema(path):
     """ Read the named database schema.
@@ -594,7 +596,7 @@ def prepare_database(db_conn):
         else:
             _setup_new_database(cur)
 
-        cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+        # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
 
         cur.close()
         db_conn.commit()
@@ -657,19 +659,17 @@ def _setup_new_database(cur):
 
     directory_entries = os.listdir(sql_dir)
 
-    sql_script = "BEGIN TRANSACTION;\n"
     for filename in fnmatch.filter(directory_entries, "*.sql"):
         sql_loc = os.path.join(sql_dir, filename)
         logger.debug("Applying schema %s", sql_loc)
-        sql_script += read_schema(sql_loc)
-        sql_script += "\n"
-    sql_script += "COMMIT TRANSACTION;"
-    cur.executescript(sql_script)
+        executescript(cur, sql_loc)
 
     cur.execute(
-        "INSERT OR REPLACE INTO schema_version (version, upgraded)"
-        " VALUES (?,?)",
-        (max_current_ver, False)
+        _convert_param_style(
+            "REPLACE INTO schema_version (version, upgraded)"
+            " VALUES (?,?)"
+        ),
+        (max_current_ver, False,)
     )
 
     _upgrade_existing_database(
@@ -737,6 +737,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
     if not upgraded:
         start_ver += 1
 
+    logger.debug("applied_delta_files: %s", applied_delta_files)
+
     for v in range(start_ver, SCHEMA_VERSION + 1):
         logger.debug("Upgrading schema to v%d", v)
 
@@ -753,6 +755,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
         directory_entries.sort()
         for file_name in directory_entries:
             relative_path = os.path.join(str(v), file_name)
+            logger.debug("Found file: %s", relative_path)
             if relative_path in applied_delta_files:
                 continue
 
@@ -774,9 +777,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
                 module.run_upgrade(cur)
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
-                delta_schema = read_schema(absolute_path)
                 logger.debug("Applying schema %s", relative_path)
-                cur.executescript(delta_schema)
+                executescript(cur, absolute_path)
             else:
                 # Not a valid delta file.
                 logger.warn(
@@ -788,24 +790,85 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
 
             # Mark as done.
             cur.execute(
-                "INSERT INTO applied_schema_deltas (version, file)"
-                " VALUES (?,?)",
+                _convert_param_style(
+                    "INSERT INTO applied_schema_deltas (version, file)"
+                    " VALUES (?,?)"
+                ),
                 (v, relative_path)
             )
 
             cur.execute(
-                "INSERT OR REPLACE INTO schema_version (version, upgraded)"
-                " VALUES (?,?)",
+                _convert_param_style(
+                    "REPLACE INTO schema_version (version, upgraded)"
+                    " VALUES (?,?)"
+                ),
                 (v, True)
             )
 
 
+def _convert_param_style(sql):
+    return sql.replace("?", "%s")
+
+
+def get_statements(f):
+    statement_buffer = ""
+    in_comment = False  # If we're in a /* ... */ style comment
+
+    for line in f:
+        line = line.strip()
+
+        if in_comment:
+            # Check if this line contains an end to the comment
+            comments = line.split("*/", 1)
+            if len(comments) == 1:
+                continue
+            line = comments[1]
+            in_comment = False
+
+        # Remove inline block comments
+        line = re.sub(r"/\*.*\*/", " ", line)
+
+        # Does this line start a comment?
+        comments = line.split("/*", 1)
+        if len(comments) > 1:
+            line = comments[0]
+            in_comment = True
+
+        # Deal with line comments
+        line = line.split("--", 1)[0]
+        line = line.split("//", 1)[0]
+
+        # Find *all* semicolons. We need to treat first and last entry
+        # specially.
+        statements = line.split(";")
+
+        # We must prepend statement_buffer to the first statement
+        first_statement = "%s %s" % (
+            statement_buffer.strip(),
+            statements[0].strip()
+        )
+        statements[0] = first_statement
+
+        # Every entry, except the last, is a full statement
+        for statement in statements[:-1]:
+            yield statement.strip()
+
+        # The last entry did *not* end in a semicolon, so we store it for the
+        # next semicolon we find
+        statement_buffer = statements[-1].strip()
+
+
+def executescript(txn, schema_path):
+    with open(schema_path, 'r') as f:
+        for statement in get_statements(f):
+            txn.execute(statement)
+
+
 def _get_or_create_schema_state(txn):
     schema_path = os.path.join(
         dir_path, "schema", "schema_version.sql",
     )
-    create_schema = read_schema(schema_path)
-    txn.executescript(create_schema)
+    executescript(txn, schema_path)
 
     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
@@ -814,10 +877,13 @@ def _get_or_create_schema_state(txn):
 
     if current_version:
         txn.execute(
-            "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+            _convert_param_style(
+                "SELECT file FROM applied_schema_deltas WHERE version >= ?"
+            ),
             (current_version,)
         )
-        return current_version, txn.fetchall(), upgraded
+        applied_deltas = [d for d, in txn.fetchall()]
+        return current_version, applied_deltas, upgraded
 
     return None
 
@@ -849,7 +915,9 @@ def prepare_sqlite3_database(db_conn):
 
             if row and row[0]:
                 db_conn.execute(
-                    "INSERT OR REPLACE INTO schema_version (version, upgraded)"
-                    " VALUES (?,?)",
+                    _convert_param_style(
+                        "REPLACE INTO schema_version (version, upgraded)"
+                        " VALUES (?,?)"
+                    ),
                     (row[0], False)
                 )