diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4b16f445d6..30cba47717 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -51,6 +51,8 @@ import logging
import os
import re
+import threading
+
logger = logging.getLogger(__name__)
@@ -89,6 +91,9 @@ class DataStore(RoomMemberStore, RoomStore,
self.min_token_deferred = self._get_min_token()
self.min_token = None
+ self._next_stream_id_lock = threading.Lock()
+ self._next_stream_id = int(hs.get_clock().time_msec()) * 1000
+
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, backfilled=False,
@@ -172,7 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
"type": s.type,
"state_key": s.state_key,
},
- or_replace=True,
)
if event.is_state() and is_new_state:
@@ -186,7 +190,6 @@ class DataStore(RoomMemberStore, RoomStore,
"type": event.type,
"state_key": event.state_key,
},
- or_replace=True,
)
for prev_state_id, _ in event.prev_state:
@@ -285,7 +288,6 @@ class DataStore(RoomMemberStore, RoomStore,
"internal_metadata": metadata_json.decode("UTF-8"),
"json": encode_canonical_json(event_dict).decode("UTF-8"),
},
- or_replace=True,
)
content = encode_canonical_json(
@@ -303,8 +305,9 @@ class DataStore(RoomMemberStore, RoomStore,
"depth": event.depth,
}
- if stream_ordering is not None:
- vals["stream_ordering"] = stream_ordering
+ if stream_ordering is None:
+ stream_ordering = self.get_next_stream_id()
+
unrec = {
k: v
@@ -322,21 +325,18 @@ class DataStore(RoomMemberStore, RoomStore,
unrec
).decode("UTF-8")
- try:
- self._simple_insert_txn(
- txn,
- "events",
- vals,
- or_replace=(not outlier),
- or_ignore=bool(outlier),
- )
- except:
- logger.warn(
- "Failed to persist, probably duplicate: %s",
- event.event_id,
- exc_info=True,
- )
- raise _RollbackButIsFineException("_persist_event")
+ sql = (
+ "INSERT INTO events"
+ " (stream_ordering, topological_ordering, event_id, type,"
+ " room_id, content, processed, outlier, depth)"
+ " VALUES (%s,?,?,?,?,?,?,?,?)"
+ ) % (stream_ordering,)
+
+ txn.execute(
+ sql,
+ (event.depth, event.event_id, event.type, event.room_id,
+ content, True, outlier, event.depth)
+ )
if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected)
@@ -357,7 +357,6 @@ class DataStore(RoomMemberStore, RoomStore,
txn,
"state_events",
vals,
- or_replace=True,
)
if is_new_state and not context.rejected:
@@ -370,7 +369,6 @@ class DataStore(RoomMemberStore, RoomStore,
"type": event.type,
"state_key": event.state_key,
},
- or_replace=True,
)
for e_id, h in event.prev_state:
@@ -383,7 +381,6 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"is_state": 1,
},
- or_ignore=True,
)
for hash_alg, hash_base64 in event.hashes.items():
@@ -408,7 +405,6 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"auth_id": auth_id,
},
- or_ignore=True,
)
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
@@ -420,8 +416,7 @@ class DataStore(RoomMemberStore, RoomStore,
# invalidate the cache for the redacted event
self._get_event_cache.pop(event.redacts)
txn.execute(
- "INSERT OR IGNORE INTO redactions "
- "(event_id, redacts) VALUES (?,?)",
+ "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts)
)
@@ -515,7 +510,8 @@ class DataStore(RoomMemberStore, RoomStore,
"ip": ip,
"user_agent": user_agent,
"last_seen": int(self._clock.time_msec()),
- }
+ },
+ or_replace=True,
)
def get_user_ip_and_agents(self, user):
@@ -559,6 +555,12 @@ class DataStore(RoomMemberStore, RoomStore,
"have_events", f,
)
+ def get_next_stream_id(self):
+ with self._next_stream_id_lock:
+ i = self._next_stream_id
+ self._next_stream_id += 1
+ return i
+
def read_schema(path):
""" Read the named database schema.
@@ -594,7 +596,7 @@ def prepare_database(db_conn):
else:
_setup_new_database(cur)
- cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+ # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close()
db_conn.commit()
@@ -657,19 +659,17 @@ def _setup_new_database(cur):
directory_entries = os.listdir(sql_dir)
- sql_script = "BEGIN TRANSACTION;\n"
for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc)
- sql_script += read_schema(sql_loc)
- sql_script += "\n"
- sql_script += "COMMIT TRANSACTION;"
- cur.executescript(sql_script)
+ executescript(cur, sql_loc)
cur.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
- (max_current_ver, False)
+ _convert_param_style(
+ "REPLACE INTO schema_version (version, upgraded)"
+ " VALUES (?,?)"
+ ),
+ (max_current_ver, False,)
)
_upgrade_existing_database(
@@ -737,6 +737,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if not upgraded:
start_ver += 1
+ logger.debug("applied_delta_files: %s", applied_delta_files)
+
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v)
@@ -753,6 +755,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
directory_entries.sort()
for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name)
+ logger.debug("Found file: %s", relative_path)
if relative_path in applied_delta_files:
continue
@@ -774,9 +777,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module.run_upgrade(cur)
elif ext == ".sql":
# A plain old .sql file, just read and execute it
- delta_schema = read_schema(absolute_path)
logger.debug("Applying schema %s", relative_path)
- cur.executescript(delta_schema)
+ executescript(cur, absolute_path)
else:
# Not a valid delta file.
logger.warn(
@@ -788,24 +790,85 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done.
cur.execute(
- "INSERT INTO applied_schema_deltas (version, file)"
- " VALUES (?,?)",
+ _convert_param_style(
+ "INSERT INTO applied_schema_deltas (version, file)"
+ " VALUES (?,?)"
+ ),
(v, relative_path)
)
cur.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
+ _convert_param_style(
+ "REPLACE INTO schema_version (version, upgraded)"
+ " VALUES (?,?)"
+ ),
(v, True)
)
+def _convert_param_style(sql):
+ return sql.replace("?", "%s")
+
+
+def get_statements(f):
+ statement_buffer = ""
+ in_comment = False # If we're in a /* ... */ style comment
+
+ for line in f:
+ line = line.strip()
+
+ if in_comment:
+ # Check if this line contains an end to the comment
+ comments = line.split("*/", 1)
+ if len(comments) == 1:
+ continue
+ line = comments[1]
+ in_comment = False
+
+ # Remove inline block comments
+ line = re.sub(r"/\*.*\*/", " ", line)
+
+ # Does this line start a comment?
+ comments = line.split("/*", 1)
+ if len(comments) > 1:
+ line = comments[0]
+ in_comment = True
+
+ # Deal with line comments
+ line = line.split("--", 1)[0]
+ line = line.split("//", 1)[0]
+
+ # Find *all* semicolons. We need to treat first and last entry
+ # specially.
+ statements = line.split(";")
+
+ # We must prepend statement_buffer to the first statement
+ first_statement = "%s %s" % (
+ statement_buffer.strip(),
+ statements[0].strip()
+ )
+ statements[0] = first_statement
+
+ # Every entry, except the last, is a full statement
+ for statement in statements[:-1]:
+ yield statement.strip()
+
+ # The last entry did *not* end in a semicolon, so we store it for the
+ # next semicolon we find
+ statement_buffer = statements[-1].strip()
+
+
+def executescript(txn, schema_path):
+ with open(schema_path, 'r') as f:
+ for statement in get_statements(f):
+ txn.execute(statement)
+
+
def _get_or_create_schema_state(txn):
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
- create_schema = read_schema(schema_path)
- txn.executescript(create_schema)
+ executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
@@ -814,10 +877,13 @@ def _get_or_create_schema_state(txn):
if current_version:
txn.execute(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+ _convert_param_style(
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?"
+ ),
(current_version,)
)
- return current_version, txn.fetchall(), upgraded
+ applied_deltas = [d for d, in txn.fetchall()]
+ return current_version, applied_deltas, upgraded
return None
@@ -849,7 +915,9 @@ def prepare_sqlite3_database(db_conn):
if row and row[0]:
db_conn.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
+ _convert_param_style(
+ "REPLACE INTO schema_version (version, upgraded)"
+ " VALUES (?,?)"
+ ),
(row[0], False)
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2979a83524..24ff872dad 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -102,6 +102,10 @@ def cached(max_entries=1000):
return wrap
+def _convert_param_style(sql):
+ return sql.replace("?", "%s")
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -122,6 +126,8 @@ class LoggingTransaction(object):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+ sql = _convert_param_style(sql)
+
try:
if args and args[0]:
values = args[0]
@@ -305,11 +311,11 @@ class SQLBaseStore(object):
The result of decoder(results)
"""
def interaction(txn):
- cursor = txn.execute(query, args)
+ txn.execute(query, args)
if decoder:
- return decoder(cursor)
+ return decoder(txn)
else:
- return cursor.fetchall()
+ return txn.fetchall()
return self.runInteraction(desc, interaction)
@@ -337,8 +343,7 @@ class SQLBaseStore(object):
def _simple_insert_txn(self, txn, table, values, or_replace=False,
or_ignore=False):
sql = "%s INTO %s (%s) VALUES(%s)" % (
- ("INSERT OR REPLACE" if or_replace else
- "INSERT OR IGNORE" if or_ignore else "INSERT"),
+ ("REPLACE" if or_replace else "INSERT"),
table,
", ".join(k for k in values),
", ".join("?" for k in values)
@@ -448,8 +453,7 @@ class SQLBaseStore(object):
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = (
- "SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
- "ORDER BY rowid asc"
+ "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
"retcol": retcol,
"table": table,
@@ -505,14 +509,14 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
txn.execute(sql, keyvalues.values())
else:
- sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s" % (
", ".join(retcols),
table
)
@@ -546,7 +550,7 @@ class SQLBaseStore(object):
retcols=None, allow_none=False):
""" Combined SELECT then UPDATE."""
if retcols:
- select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
@@ -580,8 +584,8 @@ class SQLBaseStore(object):
updatevalues.values() + keyvalues.values()
)
- if txn.rowcount == 0:
- raise StoreError(404, "No row found")
+ # if txn.rowcount == 0:
+ # raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
@@ -802,7 +806,7 @@ class Table(object):
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
- _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+ _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 850676ce6c..375265d666 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -147,11 +147,11 @@ class ApplicationServiceStore(SQLBaseStore):
return True
def _get_as_id_txn(self, txn, token):
- cursor = txn.execute(
+ txn.execute(
"SELECT id FROM application_services WHERE token=?",
(token,)
)
- res = cursor.fetchone()
+ res = txn.fetchone()
if res:
return res[0]
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 68b7d59693..0c2adffbbe 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -111,12 +111,12 @@ class DirectoryStore(SQLBaseStore):
)
def _delete_room_alias_txn(self, txn, room_alias):
- cursor = txn.execute(
+ txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),)
)
- res = cursor.fetchone()
+ res = txn.fetchone()
if res:
room_id = res[0]
else:
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 2deda8ac50..5d66b2f24c 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -242,7 +242,6 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id,
"min_depth": depth,
},
- or_replace=True,
)
def _handle_prev_events(self, txn, outlier, event_id, prev_events,
@@ -262,7 +261,6 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id,
"is_state": 0,
},
- or_ignore=True,
)
# Update the extremities table if this is not an outlier.
@@ -281,19 +279,19 @@ class EventFederationStore(SQLBaseStore):
# We only insert as a forward extremity the new event if there are
# no other events that reference it as a prev event
query = (
- "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
- "SELECT ?, ? WHERE NOT EXISTS ("
- "SELECT 1 FROM %(event_edges)s WHERE "
- "prev_event_id = ? "
- ")"
- ) % {
- "table": "event_forward_extremities",
- "event_edges": "event_edges",
- }
+ "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
+ )
- logger.debug("query: %s", query)
+ txn.execute(query, (event_id,))
+
+ if not txn.fetchone():
+ query = (
+ "INSERT INTO event_forward_extremities"
+ " (event_id, room_id)"
+ " VALUES (?, ?)"
+ )
- txn.execute(query, (event_id, room_id, event_id))
+ txn.execute(query, (event_id, room_id))
# Insert all the prev_events as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
@@ -306,7 +304,6 @@ class EventFederationStore(SQLBaseStore):
"event_id": e_id,
"room_id": room_id,
},
- or_ignore=True,
)
# Also delete from the backwards extremities table all ones that
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 1dcd34723b..0084d67e5b 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -45,7 +45,6 @@ class PresenceStore(SQLBaseStore):
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()},
- retcols=["state"],
)
def allow_presence_visible(self, observed_localpart, observer_userid):
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index d769db2c78..27a0716323 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -153,7 +153,7 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule
- sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+ sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
@@ -182,7 +182,7 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
- sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+ sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index adc8fc0794..344dd3aaac 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -39,14 +39,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
- if not row:
- raise StoreError(400, "Bad user ID supplied.")
- row_id = row["id"]
yield self._simple_insert(
"access_tokens",
{
- "user_id": row_id,
+ "user_id": user_id,
"token": token
}
)
@@ -82,7 +78,7 @@ class RegistrationStore(SQLBaseStore):
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute("INSERT INTO access_tokens(user_id, token) " +
- "VALUES (?,?)", [txn.lastrowid, token])
+ "VALUES (?,?)", [user_id, token])
def get_user_by_id(self, user_id):
query = ("SELECT users.name, users.password_hash FROM users"
@@ -124,12 +120,12 @@ class RegistrationStore(SQLBaseStore):
"SELECT users.name, users.admin,"
" access_tokens.device_id, access_tokens.id as token_id"
" FROM users"
- " INNER JOIN access_tokens on users.id = access_tokens.user_id"
+ " INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
- cursor = txn.execute(sql, (token,))
- rows = self.cursor_to_dict(cursor)
+ txn.execute(sql, (token,))
+ rows = self.cursor_to_dict(txn)
if rows:
return rows[0]
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 549c9af393..3c23f29215 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -114,9 +114,9 @@ class RoomStore(SQLBaseStore):
"name": name_subquery,
}
- c = txn.execute(sql, (is_public,))
+ txn.execute(sql, (is_public,))
- return c.fetchall()
+ return txn.fetchall()
rows = yield self.runInteraction(
"get_rooms", f
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 65ffb4627f..e8ede14cd7 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -68,7 +68,7 @@ class RoomMemberStore(SQLBaseStore):
# Update room hosts table
if event.membership == Membership.JOIN:
sql = (
- "INSERT OR IGNORE INTO room_hosts (room_id, host) "
+ "REPLACE INTO room_hosts (room_id, host) "
"VALUES (?, ?)"
)
txn.execute(sql, (event.room_id, domain))
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 456e4bd45d..888837cd1e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,6 +15,8 @@
from ._base import SQLBaseStore
+from synapse.util.stringutils import random_string
+
import logging
logger = logging.getLogger(__name__)
@@ -89,14 +91,15 @@ class StateStore(SQLBaseStore):
state_group = context.state_group
if not state_group:
+ group = _make_group_id(self._clock)
state_group = self._simple_insert_txn(
txn,
table="state_groups",
values={
+ "id": group,
"room_id": event.room_id,
"event_id": event.event_id,
},
- or_ignore=True,
)
for state in state_events.values():
@@ -110,7 +113,6 @@ class StateStore(SQLBaseStore):
"state_key": state.state_key,
"event_id": state.event_id,
},
- or_ignore=True,
)
self._simple_insert_txn(
@@ -122,3 +124,7 @@ class StateStore(SQLBaseStore):
},
or_replace=True,
)
+
+
+def _make_group_id(clock):
+ return str(int(clock.time_msec())) + random_string(5)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 09bc522210..64adb0c7fa 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -110,7 +110,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering")
else:
- return "(%d < %s OR (%d == %s AND %d < %s))" % (
+ return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
@@ -120,7 +120,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering")
else:
- return "(%d > %s OR (%d == %s AND %d >= %s))" % (
+ return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 0b8a3b7a07..b5ed5453d8 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -121,8 +121,8 @@ class TransactionStore(SQLBaseStore):
SentTransactions.select_statement("destination = ?"),
)
- results = txn.execute(query, (destination,))
- results = SentTransactions.decode_results(results)
+ txn.execute(query, (destination,))
+ results = SentTransactions.decode_results(txn)
prev_txns = [r.transaction_id for r in results]
@@ -266,7 +266,7 @@ class TransactionStore(SQLBaseStore):
retry_last_ts, retry_interval):
query = (
- "INSERT OR REPLACE INTO %s "
+ "REPLACE INTO %s "
"(destination, retry_last_ts, retry_interval) "
"VALUES (?, ?, ?) "
) % DestinationsTable.table_name
|