summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorMark Haines <mjark@negativecurvature.net>2014-11-11 16:40:50 +0000
committerMark Haines <mjark@negativecurvature.net>2014-11-11 16:40:50 +0000
commita8ceeec0fd512e287cbf71efff42015787517a5d (patch)
tree45643674a31b637799e347f2251c72417e685616 /synapse/storage/_base.py
parentno evil horizontal textarea resizing (diff)
parentFix bugs which broke federation due to changes in function signatures. (diff)
downloadsynapse-a8ceeec0fd512e287cbf71efff42015787517a5d.tar.xz
Merge pull request #12 from matrix-org/federation_authorization
Federation authorization
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py249
1 files changed, 185 insertions, 64 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 65a86e9056..a1ee0318f6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,59 +14,69 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import StoreError
 from synapse.api.events.utils import prune_event
 from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64
 
 import collections
 import copy
 import json
+import sys
+import time
 
 
 logger = logging.getLogger(__name__)
 
 sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
 
 
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging to the .execute() method."""
-    __slots__ = ["txn"]
+    __slots__ = ["txn", "name"]
 
-    def __init__(self, txn):
+    def __init__(self, txn, name):
         object.__setattr__(self, "txn", txn)
+        object.__setattr__(self, "name", name)
 
-    def __getattribute__(self, name):
-        if name == "execute":
-            return object.__getattribute__(self, "execute")
-
-        return getattr(object.__getattribute__(self, "txn"), name)
+    def __getattr__(self, name):
+        return getattr(self.txn, name)
 
     def __setattr__(self, name, value):
-        setattr(object.__getattribute__(self, "txn"), name, value)
+        setattr(self.txn, name, value)
 
     def execute(self, sql, *args, **kwargs):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
-        sql_logger.debug("[SQL] %s", sql)
+        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
         try:
             if args and args[0]:
                 values = args[0]
-                sql_logger.debug("[SQL values] " +
-                    ", ".join(("<%s>",) * len(values)), *values)
+                sql_logger.debug(
+                    "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
+                    self.name,
+                    *values
+                )
         except:
             # Don't let logging failures stop SQL from working
             pass
 
-        # TODO(paul): Here would be an excellent place to put some timing
-        #   measurements, and log (warning?) slow queries.
-        return object.__getattribute__(self, "txn").execute(
-            sql, *args, **kwargs
-        )
+        start = time.clock() * 1000
+        try:
+            return self.txn.execute(
+                sql, *args, **kwargs
+            )
+        except:
+                logger.exception("[SQL FAIL] {%s}", self.name)
+                raise
+        finally:
+            end = time.clock() * 1000
+            sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
 
 
 class SQLBaseStore(object):
+    _TXN_ID = 0
 
     def __init__(self, hs):
         self.hs = hs
@@ -74,10 +84,30 @@ class SQLBaseStore(object):
         self.event_factory = hs.get_event_factory()
         self._clock = hs.get_clock()
 
-    def runInteraction(self, func, *args, **kwargs):
+    def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
         def inner_func(txn, *args, **kwargs):
-            return func(LoggingTransaction(txn), *args, **kwargs)
+            start = time.clock() * 1000
+            txn_id = SQLBaseStore._TXN_ID
+
+            # We don't really need these to be unique, so lets stop it from
+            # growing really large.
+            self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+            name = "%s-%x" % (desc, txn_id, )
+
+            transaction_logger.debug("[TXN START] {%s}", name)
+            try:
+                return func(LoggingTransaction(txn, name), *args, **kwargs)
+            except:
+                logger.exception("[TXN FAIL] {%s}", name)
+                raise
+            finally:
+                end = time.clock() * 1000
+                transaction_logger.debug(
+                    "[TXN END] {%s} %f",
+                    name, end - start
+                )
 
         return self._db_pool.runInteraction(inner_func, *args, **kwargs)
 
@@ -113,7 +143,7 @@ class SQLBaseStore(object):
             else:
                 return cursor.fetchall()
 
-        return self.runInteraction(interaction)
+        return self.runInteraction("_execute", interaction)
 
     def _execute_and_decode(self, query, *args):
         return self._execute(self.cursor_to_dict, query, *args)
@@ -130,6 +160,7 @@ class SQLBaseStore(object):
             or_replace : bool; if True performs an INSERT OR REPLACE
         """
         return self.runInteraction(
+            "_simple_insert",
             self._simple_insert_txn, table, values, or_replace=or_replace,
             or_ignore=or_ignore,
         )
@@ -170,7 +201,6 @@ class SQLBaseStore(object):
             table, keyvalues, retcols=retcols, allow_none=allow_none
         )
 
-    @defer.inlineCallbacks
     def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                   allow_none=False):
         """Executes a SELECT query on the named table, which is expected to
@@ -181,19 +211,40 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the row with
             retcol : string giving the name of the column to return
         """
-        ret = yield self._simple_select_one(
+        return self.runInteraction(
+            "_simple_select_one_onecol",
+            self._simple_select_one_onecol_txn,
+            table, keyvalues, retcol, allow_none=allow_none,
+        )
+
+    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+                                      allow_none=False):
+        ret = self._simple_select_onecol_txn(
+            txn,
             table=table,
             keyvalues=keyvalues,
-            retcols=[retcol],
-            allow_none=allow_none
+            retcol=retcol,
         )
 
         if ret:
-            defer.returnValue(ret[retcol])
+            return ret[0]
         else:
-            defer.returnValue(None)
+            if allow_none:
+                return None
+            else:
+                raise StoreError(404, "No row found")
+
+    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+        sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+            "retcol": retcol,
+            "table": table,
+            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+        }
+
+        txn.execute(sql, keyvalues.values())
+
+        return [r[0] for r in txn.fetchall()]
 
-    @defer.inlineCallbacks
     def _simple_select_onecol(self, table, keyvalues, retcol):
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
@@ -206,25 +257,33 @@ class SQLBaseStore(object):
         Returns:
             Deferred: Results in a list
         """
-        sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
-            "retcol": retcol,
-            "table": table,
-            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
-        }
-
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            return txn.fetchall()
+        return self.runInteraction(
+            "_simple_select_onecol",
+            self._simple_select_onecol_txn,
+            table, keyvalues, retcol
+        )
 
-        res = yield self.runInteraction(func)
+    def _simple_select_list(self, table, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
 
-        defer.returnValue([r[0] for r in res])
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        return self.runInteraction(
+            "_simple_select_list",
+            self._simple_select_list_txn,
+            table, keyvalues, retcols
+        )
 
-    def _simple_select_list(self, table, keyvalues, retcols):
+    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
+            txn : Transaction object
             table : string giving the table name
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
@@ -232,14 +291,11 @@ class SQLBaseStore(object):
         sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            return self.cursor_to_dict(txn)
-
-        return self.runInteraction(func)
+        txn.execute(sql, keyvalues.values())
+        return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            retcols=None):
@@ -307,7 +363,7 @@ class SQLBaseStore(object):
                     raise StoreError(500, "More than one row matched")
 
             return ret
-        return self.runInteraction(func)
+        return self.runInteraction("_simple_selectupdate_one", func)
 
     def _simple_delete_one(self, table, keyvalues):
         """Executes a DELETE query on the named table, expecting to delete a
@@ -319,7 +375,7 @@ class SQLBaseStore(object):
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
         def func(txn):
@@ -328,7 +384,25 @@ class SQLBaseStore(object):
                 raise StoreError(404, "No row found")
             if txn.rowcount > 1:
                 raise StoreError(500, "more than one row matched")
-        return self.runInteraction(func)
+        return self.runInteraction("_simple_delete_one", func)
+
+    def _simple_delete(self, table, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+
+        return self.runInteraction("_simple_delete", self._simple_delete_txn)
+
+    def _simple_delete_txn(self, txn, table, keyvalues):
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+        )
+
+        return txn.execute(sql, keyvalues.values())
 
     def _simple_max_id(self, table):
         """Executes a SELECT query on the named table, expecting to return the
@@ -346,7 +420,7 @@ class SQLBaseStore(object):
                 return 0
             return max_id
 
-        return self.runInteraction(func)
+        return self.runInteraction("_simple_max_id", func)
 
     def _parse_event_from_row(self, row_dict):
         d = copy.deepcopy({k: v for k, v in row_dict.items()})
@@ -355,6 +429,10 @@ class SQLBaseStore(object):
         d.pop("topological_ordering", None)
         d.pop("processed", None)
         d["origin_server_ts"] = d.pop("ts", 0)
+        replaces_state = d.pop("prev_state", None)
+
+        if replaces_state:
+            d["replaces_state"] = replaces_state
 
         d.update(json.loads(row_dict["unrecognized_keys"]))
         d["content"] = json.loads(d["content"])
@@ -369,23 +447,65 @@ class SQLBaseStore(object):
             **d
         )
 
+    def _get_events_txn(self, txn, event_ids):
+        # FIXME (erikj): This should be batched?
+
+        sql = "SELECT * FROM events WHERE event_id = ?"
+
+        event_rows = []
+        for e_id in event_ids:
+            c = txn.execute(sql, (e_id,))
+            event_rows.extend(self.cursor_to_dict(c))
+
+        return self._parse_events_txn(txn, event_rows)
+
     def _parse_events(self, rows):
-        return self.runInteraction(self._parse_events_txn, rows)
+        return self.runInteraction(
+            "_parse_events", self._parse_events_txn, rows
+        )
 
     def _parse_events_txn(self, txn, rows):
         events = [self._parse_event_from_row(r) for r in rows]
 
-        sql = "SELECT * FROM events WHERE event_id = ?"
+        select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+
+        for i, ev in enumerate(events):
+            signatures = self._get_event_origin_signatures_txn(
+                txn, ev.event_id,
+            )
 
-        for ev in events:
-            if hasattr(ev, "prev_state"):
-                # Load previous state_content.
-                # TODO: Should we be pulling this out above?
-                cursor = txn.execute(sql, (ev.prev_state,))
-                prevs = self.cursor_to_dict(cursor)
-                if prevs:
-                    prev = self._parse_event_from_row(prevs[0])
-                    ev.prev_content = prev.content
+            ev.signatures = {
+                k: encode_base64(v) for k, v in signatures.items()
+            }
+
+            prevs = self._get_prev_events_and_state(txn, ev.event_id)
+
+            ev.prev_events = [
+                (e_id, h)
+                for e_id, h, is_state in prevs
+                if is_state == 0
+            ]
+
+            ev.auth_events = self._get_auth_events(txn, ev.event_id)
+
+            if hasattr(ev, "state_key"):
+                ev.prev_state = [
+                    (e_id, h)
+                    for e_id, h, is_state in prevs
+                    if is_state == 1
+                ]
+
+                if hasattr(ev, "replaces_state"):
+                    # Load previous state_content.
+                    # FIXME (erikj): Handle multiple prev_states.
+                    cursor = txn.execute(
+                        select_event_sql,
+                        (ev.replaces_state,)
+                    )
+                    prevs = self.cursor_to_dict(cursor)
+                    if prevs:
+                        prev = self._parse_event_from_row(prevs[0])
+                        ev.prev_content = prev.content
 
             if not hasattr(ev, "redacted"):
                 logger.debug("Doesn't have redacted key: %s", ev)
@@ -393,15 +513,16 @@ class SQLBaseStore(object):
 
             if ev.redacted:
                 # Get the redaction event.
-                sql = "SELECT * FROM events WHERE event_id = ?"
-                txn.execute(sql, (ev.redacted,))
+                select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+                txn.execute(select_event_sql, (ev.redacted,))
 
                 del_evs = self._parse_events_txn(
                     txn, self.cursor_to_dict(txn)
                 )
 
                 if del_evs:
-                    prune_event(ev)
+                    ev = prune_event(ev)
+                    events[i] = ev
                     ev.redacted_because = del_evs[0]
 
         return events