diff options
Diffstat (limited to 'synapse/storage/__init__.py')
-rw-r--r-- | synapse/storage/__init__.py | 164 |
1 files changed, 116 insertions, 48 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4b16f445d6..30cba47717 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -51,6 +51,8 @@ import logging import os import re +import threading + logger = logging.getLogger(__name__) @@ -89,6 +91,9 @@ class DataStore(RoomMemberStore, RoomStore, self.min_token_deferred = self._get_min_token() self.min_token = None + self._next_stream_id_lock = threading.Lock() + self._next_stream_id = int(hs.get_clock().time_msec()) * 1000 + @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False, @@ -172,7 +177,6 @@ class DataStore(RoomMemberStore, RoomStore, "type": s.type, "state_key": s.state_key, }, - or_replace=True, ) if event.is_state() and is_new_state: @@ -186,7 +190,6 @@ class DataStore(RoomMemberStore, RoomStore, "type": event.type, "state_key": event.state_key, }, - or_replace=True, ) for prev_state_id, _ in event.prev_state: @@ -285,7 +288,6 @@ class DataStore(RoomMemberStore, RoomStore, "internal_metadata": metadata_json.decode("UTF-8"), "json": encode_canonical_json(event_dict).decode("UTF-8"), }, - or_replace=True, ) content = encode_canonical_json( @@ -303,8 +305,9 @@ class DataStore(RoomMemberStore, RoomStore, "depth": event.depth, } - if stream_ordering is not None: - vals["stream_ordering"] = stream_ordering + if stream_ordering is None: + stream_ordering = self.get_next_stream_id() + unrec = { k: v @@ -322,21 +325,18 @@ class DataStore(RoomMemberStore, RoomStore, unrec ).decode("UTF-8") - try: - self._simple_insert_txn( - txn, - "events", - vals, - or_replace=(not outlier), - or_ignore=bool(outlier), - ) - except: - logger.warn( - "Failed to persist, probably duplicate: %s", - event.event_id, - exc_info=True, - ) - raise _RollbackButIsFineException("_persist_event") + sql = ( + "INSERT INTO events" + " (stream_ordering, topological_ordering, event_id, type," + " room_id, content, processed, outlier, depth)" + " VALUES (%s,?,?,?,?,?,?,?,?)" + ) % (stream_ordering,) + + txn.execute( + sql, + (event.depth, event.event_id, event.type, event.room_id, + content, True, outlier, event.depth) + ) if context.rejected: self._store_rejections_txn(txn, event.event_id, context.rejected) @@ -357,7 +357,6 @@ class DataStore(RoomMemberStore, RoomStore, txn, "state_events", vals, - or_replace=True, ) if is_new_state and not context.rejected: @@ -370,7 +369,6 @@ class DataStore(RoomMemberStore, RoomStore, "type": event.type, "state_key": event.state_key, }, - or_replace=True, ) for e_id, h in event.prev_state: @@ -383,7 +381,6 @@ class DataStore(RoomMemberStore, RoomStore, "room_id": event.room_id, "is_state": 1, }, - or_ignore=True, ) for hash_alg, hash_base64 in event.hashes.items(): @@ -408,7 +405,6 @@ class DataStore(RoomMemberStore, RoomStore, "room_id": event.room_id, "auth_id": auth_id, }, - or_ignore=True, ) (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) @@ -420,8 +416,7 @@ class DataStore(RoomMemberStore, RoomStore, # invalidate the cache for the redacted event self._get_event_cache.pop(event.redacts) txn.execute( - "INSERT OR IGNORE INTO redactions " - "(event_id, redacts) VALUES (?,?)", + "INSERT INTO redactions (event_id, redacts) VALUES (?,?)", (event.event_id, event.redacts) ) @@ -515,7 +510,8 @@ class DataStore(RoomMemberStore, RoomStore, "ip": ip, "user_agent": user_agent, "last_seen": int(self._clock.time_msec()), - } + }, + or_replace=True, ) def get_user_ip_and_agents(self, user): @@ -559,6 +555,12 @@ class DataStore(RoomMemberStore, RoomStore, "have_events", f, ) + def get_next_stream_id(self): + with self._next_stream_id_lock: + i = self._next_stream_id + self._next_stream_id += 1 + return i + def read_schema(path): """ Read the named database schema. @@ -594,7 +596,7 @@ def prepare_database(db_conn): else: _setup_new_database(cur) - cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) cur.close() db_conn.commit() @@ -657,19 +659,17 @@ def _setup_new_database(cur): directory_entries = os.listdir(sql_dir) - sql_script = "BEGIN TRANSACTION;\n" for filename in fnmatch.filter(directory_entries, "*.sql"): sql_loc = os.path.join(sql_dir, filename) logger.debug("Applying schema %s", sql_loc) - sql_script += read_schema(sql_loc) - sql_script += "\n" - sql_script += "COMMIT TRANSACTION;" - cur.executescript(sql_script) + executescript(cur, sql_loc) cur.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (max_current_ver, False) + _convert_param_style( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) ) _upgrade_existing_database( @@ -737,6 +737,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, if not upgraded: start_ver += 1 + logger.debug("applied_delta_files: %s", applied_delta_files) + for v in range(start_ver, SCHEMA_VERSION + 1): logger.debug("Upgrading schema to v%d", v) @@ -753,6 +755,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, directory_entries.sort() for file_name in directory_entries: relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) if relative_path in applied_delta_files: continue @@ -774,9 +777,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module.run_upgrade(cur) elif ext == ".sql": # A plain old .sql file, just read and execute it - delta_schema = read_schema(absolute_path) logger.debug("Applying schema %s", relative_path) - cur.executescript(delta_schema) + executescript(cur, absolute_path) else: # Not a valid delta file. logger.warn( @@ -788,24 +790,85 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, # Mark as done. cur.execute( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", + _convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)" + ), (v, relative_path) ) cur.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", + _convert_param_style( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), (v, True) ) +def _convert_param_style(sql): + return sql.replace("?", "%s") + + +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + def _get_or_create_schema_state(txn): schema_path = os.path.join( dir_path, "schema", "schema_version.sql", ) - create_schema = read_schema(schema_path) - txn.executescript(create_schema) + executescript(txn, schema_path) txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() @@ -814,10 +877,13 @@ def _get_or_create_schema_state(txn): if current_version: txn.execute( - "SELECT file FROM applied_schema_deltas WHERE version >= ?", + _convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), (current_version,) ) - return current_version, txn.fetchall(), upgraded + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded return None @@ -849,7 +915,9 @@ def prepare_sqlite3_database(db_conn): if row and row[0]: db_conn.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", + _convert_param_style( + "REPLACE INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), (row[0], False) ) |