summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py165
1 files changed, 124 insertions, 41 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 183a752387..7dc67ecd57 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,13 +15,11 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
-from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
 
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
-        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
-        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
-        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
-        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
-        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
-        self._pushers_id_gen = IdGenerator("pushers", "id", self)
-        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
-        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
-        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
@@ -197,7 +185,7 @@ class SQLBaseStore(object):
             time_then = self._previous_loop_ts
             self._previous_loop_ts = time_now
 
-            ratio = (curr - prev)/(time_now - time_then)
+            ratio = (curr - prev) / (time_now - time_then)
 
             top_three_counters = self._txn_perf_counters.interval(
                 time_now - time_then, limit=3
@@ -310,10 +298,10 @@ class SQLBaseStore(object):
                     func, *args, **kwargs
                 )
 
-        result = yield preserve_context_over_fn(
-            self._db_pool.runWithConnection,
-            inner_func, *args, **kwargs
-        )
+        with PreserveLoggingContext():
+            result = yield self._db_pool.runWithConnection(
+                inner_func, *args, **kwargs
+            )
 
         for after_callback, after_args in after_callbacks:
             after_callback(*after_args)
@@ -338,14 +326,15 @@ class SQLBaseStore(object):
 
                 return func(conn, *args, **kwargs)
 
-        result = yield preserve_context_over_fn(
-            self._db_pool.runWithConnection,
-            inner_func, *args, **kwargs
-        )
+        with PreserveLoggingContext():
+            result = yield self._db_pool.runWithConnection(
+                inner_func, *args, **kwargs
+            )
 
         defer.returnValue(result)
 
-    def cursor_to_dict(self, cursor):
+    @staticmethod
+    def cursor_to_dict(cursor):
         """Converts a SQL cursor into an list of dicts.
 
         Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
             if not or_ignore:
                 raise
 
-    @log_function
-    def _simple_insert_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_txn(txn, table, values):
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
 
         txn.execute(sql, vals)
 
-    def _simple_insert_many_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_many_txn(txn, table, values):
         if not values:
             return
 
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
             table, keyvalues, retcol, allow_none=allow_none,
         )
 
-    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+    @classmethod
+    def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
                                       allow_none=False):
-        ret = self._simple_select_onecol_txn(
+        ret = cls._simple_select_onecol_txn(
             txn,
             table=table,
             keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
             else:
                 raise StoreError(404, "No row found")
 
-    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+    @staticmethod
+    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
         sql = (
             "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
         ) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
             table, keyvalues, retcols
         )
 
-    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+    @classmethod
+    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -627,7 +620,83 @@ class SQLBaseStore(object):
             )
             txn.execute(sql)
 
-        return self.cursor_to_dict(txn)
+        return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def _simple_select_many_batch(self, table, column, iterable, retcols,
+                                  keyvalues={}, desc="_simple_select_many_batch",
+                                  batch_size=100):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        results = []
+
+        if not iterable:
+            defer.returnValue(results)
+
+        chunks = [
+            iterable[i:i + batch_size]
+            for i in xrange(0, len(iterable), batch_size)
+        ]
+        for chunk in chunks:
+            rows = yield self.runInteraction(
+                desc,
+                self._simple_select_many_txn,
+                table, column, chunk, keyvalues, retcols
+            )
+
+            results.extend(rows)
+
+        defer.returnValue(results)
+
+    @classmethod
+    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        if not iterable:
+            return []
+
+        sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+
+        clauses = []
+        values = []
+        clauses.append(
+            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+        )
+        values.extend(iterable)
+
+        for key, value in keyvalues.items():
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (
+                sql,
+                " AND ".join(clauses),
+            )
+
+        txn.execute(sql, values)
+        return cls.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
@@ -654,7 +723,8 @@ class SQLBaseStore(object):
             table, keyvalues, updatevalues,
         )
 
-    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+    @staticmethod
+    def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
         update_sql = "UPDATE %s SET %s WHERE %s" % (
             table,
             ", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -671,7 +741,8 @@ class SQLBaseStore(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched")
 
-    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+    @staticmethod
+    def _simple_select_one_txn(txn, table, keyvalues, retcols,
                                allow_none=False):
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
@@ -699,20 +770,32 @@ class SQLBaseStore(object):
             table : string giving the table name
             keyvalues : dict of column names and values to select the row with
         """
+        return self.runInteraction(
+            desc, self._simple_delete_one_txn, table, keyvalues
+        )
+
+    @staticmethod
+    def _simple_delete_one_txn(txn, table, keyvalues):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            if txn.rowcount == 0:
-                raise StoreError(404, "No row found")
-            if txn.rowcount > 1:
-                raise StoreError(500, "more than one row matched")
-        return self.runInteraction(desc, func)
+        txn.execute(sql, keyvalues.values())
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found")
+        if txn.rowcount > 1:
+            raise StoreError(500, "more than one row matched")
 
-    def _simple_delete_txn(self, txn, table, keyvalues):
+    @staticmethod
+    def _simple_delete_txn(txn, table, keyvalues):
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)