diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 65a86e9056..5d4be09a82 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,59 +14,72 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from syutil.base64util import encode_base64
+
+from twisted.internet import defer
import collections
import copy
import json
+import sys
+import time
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
- __slots__ = ["txn"]
+ __slots__ = ["txn", "name"]
- def __init__(self, txn):
+ def __init__(self, txn, name):
object.__setattr__(self, "txn", txn)
+ object.__setattr__(self, "name", name)
- def __getattribute__(self, name):
- if name == "execute":
- return object.__getattribute__(self, "execute")
-
- return getattr(object.__getattribute__(self, "txn"), name)
+ def __getattr__(self, name):
+ return getattr(self.txn, name)
def __setattr__(self, name, value):
- setattr(object.__getattribute__(self, "txn"), name, value)
+ setattr(self.txn, name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] %s", sql)
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
try:
if args and args[0]:
values = args[0]
- sql_logger.debug("[SQL values] " +
- ", ".join(("<%s>",) * len(values)), *values)
+ sql_logger.debug(
+ "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
+ self.name,
+ *values
+ )
except:
# Don't let logging failures stop SQL from working
pass
- # TODO(paul): Here would be an excellent place to put some timing
- # measurements, and log (warning?) slow queries.
- return object.__getattribute__(self, "txn").execute(
- sql, *args, **kwargs
- )
+ start = time.clock() * 1000
+ try:
+ return self.txn.execute(
+ sql, *args, **kwargs
+ )
+ except:
+ logger.exception("[SQL FAIL] {%s}", self.name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
class SQLBaseStore(object):
+ _TXN_ID = 0
def __init__(self, hs):
self.hs = hs
@@ -74,12 +87,40 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
- def runInteraction(self, func, *args, **kwargs):
+ @defer.inlineCallbacks
+ def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
+ current_context = LoggingContext.current_context()
def inner_func(txn, *args, **kwargs):
- return func(LoggingTransaction(txn), *args, **kwargs)
-
- return self._db_pool.runInteraction(inner_func, *args, **kwargs)
+ with LoggingContext("runInteraction") as context:
+ current_context.copy_to(context)
+ start = time.clock() * 1000
+ txn_id = SQLBaseStore._TXN_ID
+
+ # We don't really need these to be unique, so lets stop it from
+ # growing really large.
+ self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+ name = "%s-%x" % (desc, txn_id, )
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+ try:
+ return func(LoggingTransaction(txn, name), *args, **kwargs)
+ except:
+ logger.exception("[TXN FAIL] {%s}", name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ transaction_logger.debug(
+ "[TXN END] {%s} %f",
+ name, end - start
+ )
+
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runInteraction(
+ inner_func, *args, **kwargs
+ )
+ defer.returnValue(result)
def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts.
@@ -113,7 +154,7 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
- return self.runInteraction(interaction)
+ return self.runInteraction("_execute", interaction)
def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args)
@@ -130,6 +171,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
+ "_simple_insert",
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@@ -146,7 +188,7 @@ class SQLBaseStore(object):
)
logger.debug(
- "[SQL] %s Args=%s Func=%s",
+ "[SQL] %s Args=%s",
sql, values.values(),
)
@@ -170,7 +212,6 @@ class SQLBaseStore(object):
table, keyvalues, retcols=retcols, allow_none=allow_none
)
- @defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@@ -181,19 +222,40 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return
"""
- ret = yield self._simple_select_one(
+ return self.runInteraction(
+ "_simple_select_one_onecol",
+ self._simple_select_one_onecol_txn,
+ table, keyvalues, retcol, allow_none=allow_none,
+ )
+
+ def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ allow_none=False):
+ ret = self._simple_select_onecol_txn(
+ txn,
table=table,
keyvalues=keyvalues,
- retcols=[retcol],
- allow_none=allow_none
+ retcol=retcol,
)
if ret:
- defer.returnValue(ret[retcol])
+ return ret[0]
else:
- defer.returnValue(None)
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+ "retcol": retcol,
+ "table": table,
+ "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ }
+
+ txn.execute(sql, keyvalues.values())
+
+ return [r[0] for r in txn.fetchall()]
- @defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -206,25 +268,33 @@ class SQLBaseStore(object):
Returns:
Deferred: Results in a list
"""
- sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
- "retcol": retcol,
- "table": table,
- "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
- }
-
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return txn.fetchall()
+ return self.runInteraction(
+ "_simple_select_onecol",
+ self._simple_select_onecol_txn,
+ table, keyvalues, retcol
+ )
- res = yield self.runInteraction(func)
+ def _simple_select_list(self, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
- defer.returnValue([r[0] for r in res])
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ return self.runInteraction(
+ "_simple_select_list",
+ self._simple_select_list_txn,
+ table, keyvalues, retcols
+ )
- def _simple_select_list(self, table, keyvalues, retcols):
+ def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
+ txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
@@ -232,14 +302,11 @@ class SQLBaseStore(object):
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return self.cursor_to_dict(txn)
-
- return self.runInteraction(func)
+ txn.execute(sql, keyvalues.values())
+ return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@@ -307,7 +374,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched")
return ret
- return self.runInteraction(func)
+ return self.runInteraction("_simple_selectupdate_one", func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
@@ -319,7 +386,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
@@ -328,7 +395,25 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction(func)
+ return self.runInteraction("_simple_delete_one", func)
+
+ def _simple_delete(self, table, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+
+ return self.runInteraction("_simple_delete", self._simple_delete_txn)
+
+ def _simple_delete_txn(self, txn, table, keyvalues):
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ )
+
+ return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
@@ -346,7 +431,7 @@ class SQLBaseStore(object):
return 0
return max_id
- return self.runInteraction(func)
+ return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items()})
@@ -355,6 +440,10 @@ class SQLBaseStore(object):
d.pop("topological_ordering", None)
d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0)
+ replaces_state = d.pop("prev_state", None)
+
+ if replaces_state:
+ d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"])
@@ -369,23 +458,76 @@ class SQLBaseStore(object):
**d
)
+ def _get_events_txn(self, txn, event_ids):
+ # FIXME (erikj): This should be batched?
+
+ sql = "SELECT * FROM events WHERE event_id = ?"
+
+ event_rows = []
+ for e_id in event_ids:
+ c = txn.execute(sql, (e_id,))
+ event_rows.extend(self.cursor_to_dict(c))
+
+ return self._parse_events_txn(txn, event_rows)
+
def _parse_events(self, rows):
- return self.runInteraction(self._parse_events_txn, rows)
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
- sql = "SELECT * FROM events WHERE event_id = ?"
+ select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+
+ for i, ev in enumerate(events):
+ signatures = self._get_event_signatures_txn(
+ txn, ev.event_id,
+ )
+
+ ev.signatures = {
+ n: {
+ k: encode_base64(v) for k, v in s.items()
+ }
+ for n, s in signatures.items()
+ }
+
+ hashes = self._get_event_content_hashes_txn(
+ txn, ev.event_id,
+ )
- for ev in events:
- if hasattr(ev, "prev_state"):
- # Load previous state_content.
- # TODO: Should we be pulling this out above?
- cursor = txn.execute(sql, (ev.prev_state,))
- prevs = self.cursor_to_dict(cursor)
- if prevs:
- prev = self._parse_event_from_row(prevs[0])
- ev.prev_content = prev.content
+ ev.hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ }
+
+ prevs = self._get_prev_events_and_state(txn, ev.event_id)
+
+ ev.prev_events = [
+ (e_id, h)
+ for e_id, h, is_state in prevs
+ if is_state == 0
+ ]
+
+ ev.auth_events = self._get_auth_events(txn, ev.event_id)
+
+ if hasattr(ev, "state_key"):
+ ev.prev_state = [
+ (e_id, h)
+ for e_id, h, is_state in prevs
+ if is_state == 1
+ ]
+
+ if hasattr(ev, "replaces_state"):
+ # Load previous state_content.
+ # FIXME (erikj): Handle multiple prev_states.
+ cursor = txn.execute(
+ select_event_sql,
+ (ev.replaces_state,)
+ )
+ prevs = self.cursor_to_dict(cursor)
+ if prevs:
+ prev = self._parse_event_from_row(prevs[0])
+ ev.prev_content = prev.content
if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev)
@@ -393,15 +535,16 @@ class SQLBaseStore(object):
if ev.redacted:
# Get the redaction event.
- sql = "SELECT * FROM events WHERE event_id = ?"
- txn.execute(sql, (ev.redacted,))
+ select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+ txn.execute(select_event_sql, (ev.redacted,))
del_evs = self._parse_events_txn(
txn, self.cursor_to_dict(txn)
)
if del_evs:
- prune_event(ev)
+ ev = prune_event(ev)
+ events[i] = ev
ev.redacted_because = del_evs[0]
return events
|