summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py161
1 files changed, 150 insertions, 11 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f660fc6eaf..3e1ab0a159 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -77,6 +77,43 @@ class LoggingTransaction(object):
             sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
 
 
+class PerformanceCounters(object):
+    def __init__(self):
+        self.current_counters = {}
+        self.previous_counters = {}
+
+    def update(self, key, start_time, end_time=None):
+        if end_time is None:
+            end_time = time.time() * 1000
+        duration = end_time - start_time
+        count, cum_time = self.current_counters.get(key, (0, 0))
+        count += 1
+        cum_time += duration
+        self.current_counters[key] = (count, cum_time)
+        return end_time
+
+    def interval(self, interval_duration, limit=3):
+        counters = []
+        for name, (count, cum_time) in self.current_counters.items():
+            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
+            counters.append((
+                (cum_time - prev_time) / interval_duration,
+                count - prev_count,
+                name
+            ))
+
+        self.previous_counters = dict(self.current_counters)
+
+        counters.sort(reverse=True)
+
+        top_n_counters = ", ".join(
+            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
+            for ratio, count, name in counters[:limit]
+        )
+
+        return top_n_counters
+
+
 class SQLBaseStore(object):
     _TXN_ID = 0
 
@@ -85,6 +122,41 @@ class SQLBaseStore(object):
         self._db_pool = hs.get_db_pool()
         self._clock = hs.get_clock()
 
+        self._previous_txn_total_time = 0
+        self._current_txn_total_time = 0
+        self._previous_loop_ts = 0
+        self._txn_perf_counters = PerformanceCounters()
+        self._get_event_counters = PerformanceCounters()
+
+    def start_profiling(self):
+        self._previous_loop_ts = self._clock.time_msec()
+
+        def loop():
+            curr = self._current_txn_total_time
+            prev = self._previous_txn_total_time
+            self._previous_txn_total_time = curr
+
+            time_now = self._clock.time_msec()
+            time_then = self._previous_loop_ts
+            self._previous_loop_ts = time_now
+
+            ratio = (curr - prev)/(time_now - time_then)
+
+            top_three_counters = self._txn_perf_counters.interval(
+                time_now - time_then, limit=3
+            )
+
+            top_3_event_counters = self._get_event_counters.interval(
+                time_now - time_then, limit=3
+            )
+
+            logger.info(
+                "Total database time: %.3f%% {%s} {%s}",
+                ratio * 100, top_three_counters, top_3_event_counters
+            )
+
+        self._clock.looping_call(loop, 10000)
+
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
@@ -94,7 +166,7 @@ class SQLBaseStore(object):
             with LoggingContext("runInteraction") as context:
                 current_context.copy_to(context)
                 start = time.time() * 1000
-                txn_id = SQLBaseStore._TXN_ID
+                txn_id = self._TXN_ID
 
                 # We don't really need these to be unique, so lets stop it from
                 # growing really large.
@@ -114,6 +186,10 @@ class SQLBaseStore(object):
                         "[TXN END] {%s} %f",
                         name, end - start
                     )
+
+                    self._current_txn_total_time += end - start
+                    self._txn_perf_counters.update(desc, start, end)
+
         with PreserveLoggingContext():
             result = yield self._db_pool.runInteraction(
                 inner_func, *args, **kwargs
@@ -193,6 +269,50 @@ class SQLBaseStore(object):
         txn.execute(sql, values.values())
         return txn.lastrowid
 
+    def _simple_upsert(self, table, keyvalues, values):
+        """
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key tables and their new values
+            values (dict): The nonunique columns and their new values
+        Returns: A deferred
+        """
+        return self.runInteraction(
+            "_simple_upsert",
+            self._simple_upsert_txn, table, keyvalues, values
+        )
+
+    def _simple_upsert_txn(self, txn, table, keyvalues, values):
+        # Try to update
+        sql = "UPDATE %s SET %s WHERE %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in values),
+            " AND ".join("%s = ?" % (k,) for k in keyvalues)
+        )
+        sqlargs = values.values() + keyvalues.values()
+        logger.debug(
+            "[SQL] %s Args=%s",
+            sql, sqlargs,
+        )
+
+        txn.execute(sql, sqlargs)
+        if txn.rowcount == 0:
+            # We didn't update and rows so insert a new one
+            allvalues = {}
+            allvalues.update(keyvalues)
+            allvalues.update(values)
+
+            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+                table,
+                ", ".join(k for k in allvalues),
+                ", ".join("?" for _ in allvalues)
+            )
+            logger.debug(
+                "[SQL] %s Args=%s",
+                sql, keyvalues.values(),
+            )
+            txn.execute(sql, allvalues.values())
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False):
         """Executes a SELECT query on the named table, which is expected to
@@ -344,8 +464,8 @@ class SQLBaseStore(object):
         if updatevalues:
             update_sql = "UPDATE %s SET %s WHERE %s" % (
                 table,
-                ", ".join("%s = ?" % (k) for k in updatevalues),
-                " AND ".join("%s = ?" % (k) for k in keyvalues)
+                ", ".join("%s = ?" % (k,) for k in updatevalues),
+                " AND ".join("%s = ?" % (k,) for k in keyvalues)
             )
 
         def func(txn):
@@ -458,14 +578,18 @@ class SQLBaseStore(object):
         return [e for e in events if e]
 
     def _get_event_txn(self, txn, event_id, check_redacted=True,
-                       get_prev_content=False):
+                       get_prev_content=False, allow_rejected=False):
         sql = (
-            "SELECT internal_metadata, json, r.event_id FROM event_json as e "
+            "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
+            "FROM event_json as e "
             "LEFT JOIN redactions as r ON e.event_id = r.redacts "
+            "LEFT JOIN rejections as rej on rej.event_id = e.event_id  "
             "WHERE e.event_id = ? "
             "LIMIT 1 "
         )
 
+        start_time = time.time() * 1000
+
         txn.execute(sql, (event_id,))
 
         res = txn.fetchone()
@@ -473,20 +597,33 @@ class SQLBaseStore(object):
         if not res:
             return None
 
-        internal_metadata, js, redacted = res
+        internal_metadata, js, redacted, rejected_reason = res
 
-        return self._get_event_from_row_txn(
-            txn, internal_metadata, js, redacted,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-        )
+        self._get_event_counters.update("select_event", start_time)
+
+        if allow_rejected or not rejected_reason:
+            return self._get_event_from_row_txn(
+                txn, internal_metadata, js, redacted,
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+            )
+        else:
+            return None
 
     def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
                                 check_redacted=True, get_prev_content=False):
+
+        start_time = time.time() * 1000
+        update_counter = self._get_event_counters.update
+
         d = json.loads(js)
+        start_time = update_counter("decode_json", start_time)
+
         internal_metadata = json.loads(internal_metadata)
+        start_time = update_counter("decode_internal", start_time)
 
         ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
+        start_time = update_counter("build_frozen_event", start_time)
 
         if check_redacted and redacted:
             ev = prune_event(ev)
@@ -502,6 +639,7 @@ class SQLBaseStore(object):
 
             if because:
                 ev.unsigned["redacted_because"] = because
+            start_time = update_counter("redact_event", start_time)
 
         if get_prev_content and "replaces_state" in ev.unsigned:
             prev = self._get_event_txn(
@@ -511,6 +649,7 @@ class SQLBaseStore(object):
             )
             if prev:
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]
+            start_time = update_counter("get_prev_content", start_time)
 
         return ev