summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorAmber Brown <hawkowl@atleastfornow.net>2019-04-03 20:07:29 +1100
committerRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2019-04-03 10:07:29 +0100
commit7efd1d87c2c424365c99ba6103135edb1845fd88 (patch)
treeeadf93e88f277cca7f6fb694c8457ca5c73f5646 /synapse/storage/_base.py
parentMerge pull request #4991 from matrix-org/erikj/stagger_push_startup (diff)
downloadsynapse-7efd1d87c2c424365c99ba6103135edb1845fd88.tar.xz
Run black on the rest of the storage module (#4996)
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py396
1 files changed, 173 insertions, 223 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 653b32fbf5..131820628a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -41,7 +41,7 @@ try:
     MAX_TXN_ID = sys.maxint - 1
 except AttributeError:
     # python 3 does not have a maximum int value
-    MAX_TXN_ID = 2**63 - 1
+    MAX_TXN_ID = 2 ** 63 - 1
 
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -76,12 +76,18 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
+
     __slots__ = [
-        "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
+        "txn",
+        "name",
+        "database_engine",
+        "after_callbacks",
+        "exception_callbacks",
     ]
 
-    def __init__(self, txn, name, database_engine, after_callbacks,
-                 exception_callbacks):
+    def __init__(
+        self, txn, name, database_engine, after_callbacks, exception_callbacks
+    ):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
@@ -110,6 +116,7 @@ class LoggingTransaction(object):
     def execute_batch(self, sql, args):
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch
+
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
             for val in args:
@@ -134,10 +141,7 @@ class LoggingTransaction(object):
         sql = self.database_engine.convert_param_style(sql)
         if args:
             try:
-                sql_logger.debug(
-                    "[SQL values] {%s} %r",
-                    self.name, args[0]
-                )
+                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
             except Exception:
                 # Don't let logging failures stop SQL from working
                 pass
@@ -145,9 +149,7 @@ class LoggingTransaction(object):
         start = time.time()
 
         try:
-            return func(
-                sql, *args
-            )
+            return func(sql, *args)
         except Exception as e:
             logger.debug("[SQL FAIL] {%s} %s", self.name, e)
             raise
@@ -176,11 +178,9 @@ class PerformanceCounters(object):
         counters = []
         for name, (count, cum_time) in iteritems(self.current_counters):
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
-            counters.append((
-                (cum_time - prev_time) / interval_duration,
-                count - prev_count,
-                name
-            ))
+            counters.append(
+                ((cum_time - prev_time) / interval_duration, count - prev_count, name)
+            )
 
         self.previous_counters = dict(self.current_counters)
 
@@ -212,8 +212,9 @@ class SQLBaseStore(object):
         self._txn_perf_counters = PerformanceCounters()
         self._get_event_counters = PerformanceCounters()
 
-        self._get_event_cache = Cache("*getEvent*", keylen=3,
-                                      max_entries=hs.config.event_cache_size)
+        self._get_event_cache = Cache(
+            "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+        )
 
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
@@ -239,7 +240,7 @@ class SQLBaseStore(object):
                 0.0,
                 run_as_background_process,
                 "upsert_safety_check",
-                self._check_safe_to_upsert
+                self._check_safe_to_upsert,
             )
 
     @defer.inlineCallbacks
@@ -271,7 +272,7 @@ class SQLBaseStore(object):
                 15.0,
                 run_as_background_process,
                 "upsert_safety_check",
-                self._check_safe_to_upsert
+                self._check_safe_to_upsert,
             )
 
     def start_profiling(self):
@@ -298,13 +299,16 @@ class SQLBaseStore(object):
 
             perf_logger.info(
                 "Total database time: %.3f%% {%s} {%s}",
-                ratio * 100, top_three_counters, top_3_event_counters
+                ratio * 100,
+                top_three_counters,
+                top_3_event_counters,
             )
 
         self._clock.looping_call(loop, 10000)
 
-    def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
-                         func, *args, **kwargs):
+    def _new_transaction(
+        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+    ):
         start = time.time()
         txn_id = self._TXN_ID
 
@@ -312,7 +316,7 @@ class SQLBaseStore(object):
         # growing really large.
         self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
 
-        name = "%s-%x" % (desc, txn_id, )
+        name = "%s-%x" % (desc, txn_id)
 
         transaction_logger.debug("[TXN START] {%s}", name)
 
@@ -323,7 +327,10 @@ class SQLBaseStore(object):
                 try:
                     txn = conn.cursor()
                     txn = LoggingTransaction(
-                        txn, name, self.database_engine, after_callbacks,
+                        txn,
+                        name,
+                        self.database_engine,
+                        after_callbacks,
                         exception_callbacks,
                     )
                     r = func(txn, *args, **kwargs)
@@ -334,7 +341,10 @@ class SQLBaseStore(object):
                     # transaction.
                     logger.warning(
                         "[TXN OPERROR] {%s} %s %d/%d",
-                        name, exception_to_unicode(e), i, N
+                        name,
+                        exception_to_unicode(e),
+                        i,
+                        N,
                     )
                     if i < N:
                         i += 1
@@ -342,8 +352,7 @@ class SQLBaseStore(object):
                             conn.rollback()
                         except self.database_engine.module.Error as e1:
                             logger.warning(
-                                "[TXN EROLL] {%s} %s",
-                                name, exception_to_unicode(e1),
+                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
                             )
                         continue
                     raise
@@ -357,7 +366,8 @@ class SQLBaseStore(object):
                             except self.database_engine.module.Error as e1:
                                 logger.warning(
                                     "[TXN EROLL] {%s} %s",
-                                    name, exception_to_unicode(e1),
+                                    name,
+                                    exception_to_unicode(e1),
                                 )
                             continue
                     raise
@@ -396,16 +406,17 @@ class SQLBaseStore(object):
         exception_callbacks = []
 
         if LoggingContext.current_context() == LoggingContext.sentinel:
-            logger.warn(
-                "Starting db txn '%s' from sentinel context",
-                desc,
-            )
+            logger.warn("Starting db txn '%s' from sentinel context", desc)
 
         try:
             result = yield self.runWithConnection(
                 self._new_transaction,
-                desc, after_callbacks, exception_callbacks, func,
-                *args, **kwargs
+                desc,
+                after_callbacks,
+                exception_callbacks,
+                func,
+                *args,
+                **kwargs
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
@@ -434,7 +445,7 @@ class SQLBaseStore(object):
         parent_context = LoggingContext.current_context()
         if parent_context == LoggingContext.sentinel:
             logger.warn(
-                "Starting db connection from sentinel context: metrics will be lost",
+                "Starting db connection from sentinel context: metrics will be lost"
             )
             parent_context = None
 
@@ -453,9 +464,7 @@ class SQLBaseStore(object):
                 return func(conn, *args, **kwargs)
 
         with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
+            result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
 
         defer.returnValue(result)
 
@@ -469,9 +478,7 @@ class SQLBaseStore(object):
             A list of dicts where the key is the column header.
         """
         col_headers = list(intern(str(column[0])) for column in cursor.description)
-        results = list(
-            dict(zip(col_headers, row)) for row in cursor
-        )
+        results = list(dict(zip(col_headers, row)) for row in cursor)
         return results
 
     def _execute(self, desc, decoder, query, *args):
@@ -485,6 +492,7 @@ class SQLBaseStore(object):
         Returns:
             The result of decoder(results)
         """
+
         def interaction(txn):
             txn.execute(query, args)
             if decoder:
@@ -498,8 +506,7 @@ class SQLBaseStore(object):
     # no complex WHERE clauses, just a dict of values for columns.
 
     @defer.inlineCallbacks
-    def _simple_insert(self, table, values, or_ignore=False,
-                       desc="_simple_insert"):
+    def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
         """Executes an INSERT query on the named table.
 
         Args:
@@ -511,10 +518,7 @@ class SQLBaseStore(object):
             `or_ignore` is True
         """
         try:
-            yield self.runInteraction(
-                desc,
-                self._simple_insert_txn, table, values,
-            )
+            yield self.runInteraction(desc, self._simple_insert_txn, table, values)
         except self.database_engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
@@ -530,15 +534,13 @@ class SQLBaseStore(object):
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
             ", ".join(k for k in keys),
-            ", ".join("?" for _ in keys)
+            ", ".join("?" for _ in keys),
         )
 
         txn.execute(sql, vals)
 
     def _simple_insert_many(self, table, values, desc):
-        return self.runInteraction(
-            desc, self._simple_insert_many_txn, table, values
-        )
+        return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
 
     @staticmethod
     def _simple_insert_many_txn(txn, table, values):
@@ -553,24 +555,18 @@ class SQLBaseStore(object):
         #
         # The sort is to ensure that we don't rely on dictionary iteration
         # order.
-        keys, vals = zip(*[
-            zip(
-                *(sorted(i.items(), key=lambda kv: kv[0]))
-            )
-            for i in values
-            if i
-        ])
+        keys, vals = zip(
+            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+        )
 
         for k in keys:
             if k != keys[0]:
-                raise RuntimeError(
-                    "All items must have the same keys"
-                )
+                raise RuntimeError("All items must have the same keys")
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
             ", ".join(k for k in keys[0]),
-            ", ".join("?" for _ in keys[0])
+            ", ".join("?" for _ in keys[0]),
         )
 
         txn.executemany(sql, vals)
@@ -583,7 +579,7 @@ class SQLBaseStore(object):
         values,
         insertion_values={},
         desc="_simple_upsert",
-        lock=True
+        lock=True,
     ):
         """
 
@@ -635,13 +631,7 @@ class SQLBaseStore(object):
                 )
 
     def _simple_upsert_txn(
-        self,
-        txn,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        lock=True,
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
     ):
         """
         Pick the UPSERT method which works best on the platform. Either the
@@ -665,11 +655,7 @@ class SQLBaseStore(object):
             and table not in self._unsafe_to_upsert_tables
         ):
             return self._simple_upsert_txn_native_upsert(
-                txn,
-                table,
-                keyvalues,
-                values,
-                insertion_values=insertion_values,
+                txn, table, keyvalues, values, insertion_values=insertion_values
             )
         else:
             return self._simple_upsert_txn_emulated(
@@ -714,7 +700,7 @@ class SQLBaseStore(object):
             # SELECT instead to see if it exists.
             sql = "SELECT 1 FROM %s WHERE %s" % (
                 table,
-                " AND ".join(_getwhere(k) for k in keyvalues)
+                " AND ".join(_getwhere(k) for k in keyvalues),
             )
             sqlargs = list(keyvalues.values())
             txn.execute(sql, sqlargs)
@@ -726,7 +712,7 @@ class SQLBaseStore(object):
             sql = "UPDATE %s SET %s WHERE %s" % (
                 table,
                 ", ".join("%s = ?" % (k,) for k in values),
-                " AND ".join(_getwhere(k) for k in keyvalues)
+                " AND ".join(_getwhere(k) for k in keyvalues),
             )
             sqlargs = list(values.values()) + list(keyvalues.values())
 
@@ -773,19 +759,14 @@ class SQLBaseStore(object):
             latter = "NOTHING"
         else:
             allvalues.update(values)
-            latter = (
-                "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
-            )
+            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
 
-        sql = (
-            "INSERT INTO %s (%s) VALUES (%s) "
-            "ON CONFLICT (%s) DO %s"
-        ) % (
+        sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
             table,
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues),
             ", ".join(k for k in keyvalues),
-            latter
+            latter,
         )
         txn.execute(sql, list(allvalues.values()))
 
@@ -870,8 +851,8 @@ class SQLBaseStore(object):
             latter = "NOTHING"
             value_values = [() for x in range(len(key_values))]
         else:
-            latter = (
-                "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
+            latter = "UPDATE SET " + ", ".join(
+                k + "=EXCLUDED." + k for k in value_names
             )
 
         sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
@@ -889,8 +870,9 @@ class SQLBaseStore(object):
 
         return txn.execute_batch(sql, args)
 
-    def _simple_select_one(self, table, keyvalues, retcols,
-                           allow_none=False, desc="_simple_select_one"):
+    def _simple_select_one(
+        self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
+    ):
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
@@ -903,14 +885,17 @@ class SQLBaseStore(object):
               statement returns no rows
         """
         return self.runInteraction(
-            desc,
-            self._simple_select_one_txn,
-            table, keyvalues, retcols, allow_none,
+            desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
-    def _simple_select_one_onecol(self, table, keyvalues, retcol,
-                                  allow_none=False,
-                                  desc="_simple_select_one_onecol"):
+    def _simple_select_one_onecol(
+        self,
+        table,
+        keyvalues,
+        retcol,
+        allow_none=False,
+        desc="_simple_select_one_onecol",
+    ):
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
@@ -922,17 +907,18 @@ class SQLBaseStore(object):
         return self.runInteraction(
             desc,
             self._simple_select_one_onecol_txn,
-            table, keyvalues, retcol, allow_none=allow_none,
+            table,
+            keyvalues,
+            retcol,
+            allow_none=allow_none,
         )
 
     @classmethod
-    def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
-                                      allow_none=False):
+    def _simple_select_one_onecol_txn(
+        cls, txn, table, keyvalues, retcol, allow_none=False
+    ):
         ret = cls._simple_select_onecol_txn(
-            txn,
-            table=table,
-            keyvalues=keyvalues,
-            retcol=retcol,
+            txn, table=table, keyvalues=keyvalues, retcol=retcol
         )
 
         if ret:
@@ -945,12 +931,7 @@ class SQLBaseStore(object):
 
     @staticmethod
     def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
-        sql = (
-            "SELECT %(retcol)s FROM %(table)s"
-        ) % {
-            "retcol": retcol,
-            "table": table,
-        }
+        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
         if keyvalues:
             sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
@@ -960,8 +941,9 @@ class SQLBaseStore(object):
 
         return [r[0] for r in txn]
 
-    def _simple_select_onecol(self, table, keyvalues, retcol,
-                              desc="_simple_select_onecol"):
+    def _simple_select_onecol(
+        self, table, keyvalues, retcol, desc="_simple_select_onecol"
+    ):
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
@@ -974,13 +956,12 @@ class SQLBaseStore(object):
             Deferred: Results in a list
         """
         return self.runInteraction(
-            desc,
-            self._simple_select_onecol_txn,
-            table, keyvalues, retcol
+            desc, self._simple_select_onecol_txn, table, keyvalues, retcol
         )
 
-    def _simple_select_list(self, table, keyvalues, retcols,
-                            desc="_simple_select_list"):
+    def _simple_select_list(
+        self, table, keyvalues, retcols, desc="_simple_select_list"
+    ):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -994,9 +975,7 @@ class SQLBaseStore(object):
             defer.Deferred: resolves to list[dict[str, Any]]
         """
         return self.runInteraction(
-            desc,
-            self._simple_select_list_txn,
-            table, keyvalues, retcols
+            desc, self._simple_select_list_txn, table, keyvalues, retcols
         )
 
     @classmethod
@@ -1016,22 +995,26 @@ class SQLBaseStore(object):
             sql = "SELECT %s FROM %s WHERE %s" % (
                 ", ".join(retcols),
                 table,
-                " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+                " AND ".join("%s = ?" % (k,) for k in keyvalues),
             )
             txn.execute(sql, list(keyvalues.values()))
         else:
-            sql = "SELECT %s FROM %s" % (
-                ", ".join(retcols),
-                table
-            )
+            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
             txn.execute(sql)
 
         return cls.cursor_to_dict(txn)
 
     @defer.inlineCallbacks
-    def _simple_select_many_batch(self, table, column, iterable, retcols,
-                                  keyvalues={}, desc="_simple_select_many_batch",
-                                  batch_size=100):
+    def _simple_select_many_batch(
+        self,
+        table,
+        column,
+        iterable,
+        retcols,
+        keyvalues={},
+        desc="_simple_select_many_batch",
+        batch_size=100,
+    ):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
         it_list = list(iterable)
 
         chunks = [
-            it_list[i:i + batch_size]
-            for i in range(0, len(it_list), batch_size)
+            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
             rows = yield self.runInteraction(
                 desc,
                 self._simple_select_many_txn,
-                table, column, chunk, keyvalues, retcols
+                table,
+                column,
+                chunk,
+                keyvalues,
+                retcols,
             )
 
             results.extend(rows)
@@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
 
         clauses = []
         values = []
-        clauses.append(
-            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
-        )
+        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
         values.extend(iterable)
 
         for key, value in iteritems(keyvalues):
@@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
             values.append(value)
 
         if clauses:
-            sql = "%s WHERE %s" % (
-                sql,
-                " AND ".join(clauses),
-            )
+            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
 
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
     def _simple_update(self, table, keyvalues, updatevalues, desc):
         return self.runInteraction(
-            desc,
-            self._simple_update_txn,
-            table, keyvalues, updatevalues,
+            desc, self._simple_update_txn, table, keyvalues, updatevalues
         )
 
     @staticmethod
@@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
             where,
         )
 
-        txn.execute(
-            update_sql,
-            list(updatevalues.values()) + list(keyvalues.values())
-        )
+        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
 
         return txn.rowcount
 
-    def _simple_update_one(self, table, keyvalues, updatevalues,
-                           desc="_simple_update_one"):
+    def _simple_update_one(
+        self, table, keyvalues, updatevalues, desc="_simple_update_one"
+    ):
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
@@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
         the update column in the 'keyvalues' dict as well.
         """
         return self.runInteraction(
-            desc,
-            self._simple_update_one_txn,
-            table, keyvalues, updatevalues,
+            desc, self._simple_update_one_txn, table, keyvalues, updatevalues
         )
 
     @classmethod
@@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
     @staticmethod
-    def _simple_select_one_txn(txn, table, keyvalues, retcols,
-                               allow_none=False):
+    def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues)
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
         )
 
         txn.execute(select_sql, list(keyvalues.values()))
@@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
             table : string giving the table name
             keyvalues : dict of column names and values to select the row with
         """
-        return self.runInteraction(
-            desc, self._simple_delete_one_txn, table, keyvalues
-        )
+        return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
 
     @staticmethod
     def _simple_delete_one_txn(txn, table, keyvalues):
@@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
-            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
         )
 
         txn.execute(sql, list(keyvalues.values()))
@@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
     def _simple_delete(self, table, keyvalues, desc):
-        return self.runInteraction(
-            desc, self._simple_delete_txn, table, keyvalues
-        )
+        return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
 
     @staticmethod
     def _simple_delete_txn(txn, table, keyvalues):
         sql = "DELETE FROM %s WHERE %s" % (
             table,
-            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
         )
 
         return txn.execute(sql, list(keyvalues.values()))
@@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
 
         clauses = []
         values = []
-        clauses.append(
-            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
-        )
+        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
         values.extend(iterable)
 
         for key, value in iteritems(keyvalues):
@@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
             values.append(value)
 
         if clauses:
-            sql = "%s WHERE %s" % (
-                sql,
-                " AND ".join(clauses),
-            )
+            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
         return txn.execute(sql, values)
 
-    def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
-                        max_value, limit=100000):
+    def _get_cache_dict(
+        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+    ):
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
@@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
         txn = db_conn.cursor()
         txn.execute(sql, (int(max_value),))
 
-        cache = {
-            row[0]: int(row[1])
-            for row in txn
-        }
+        cache = {row[0]: int(row[1]) for row in txn}
 
         txn.close()
 
@@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
         # be safe.
         for chunk in batch_iter(members_changed, 50):
             keys = itertools.chain([room_id], chunk)
-            self._send_invalidation_to_replication(
-                txn, _CURRENT_STATE_CACHE_NAME, keys,
-            )
+            self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
 
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
@@ -1356,22 +1317,12 @@ class SQLBaseStore(object):
                 changed
         """
         for host in set(get_domain_from_id(u) for u in members_changed):
-            self._attempt_to_invalidate_cache(
-                "is_host_joined", (room_id, host,),
-            )
-            self._attempt_to_invalidate_cache(
-                "was_host_joined", (room_id, host,),
-            )
+            self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
+            self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
 
-        self._attempt_to_invalidate_cache(
-            "get_users_in_room", (room_id,),
-        )
-        self._attempt_to_invalidate_cache(
-            "get_room_summary", (room_id,),
-        )
-        self._attempt_to_invalidate_cache(
-            "get_current_state_ids", (room_id,),
-        )
+        self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+        self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
+        self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
 
     def _attempt_to_invalidate_cache(self, cache_name, key):
         """Attempts to invalidate the cache of the given name, ignoring if the
@@ -1419,7 +1370,7 @@ class SQLBaseStore(object):
                     "cache_func": cache_name,
                     "keys": list(keys),
                     "invalidation_ts": self.clock.time_msec(),
-                }
+                },
             )
 
     def get_all_updated_caches(self, last_id, current_id, limit):
@@ -1435,11 +1386,10 @@ class SQLBaseStore(object):
                 " FROM cache_invalidation_stream"
                 " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_id, limit,))
+            txn.execute(sql, (last_id, limit))
             return txn.fetchall()
-        return self.runInteraction(
-            "get_all_updated_caches", get_all_updated_caches_txn
-        )
+
+        return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
 
     def get_cache_stream_token(self):
         if self._cache_id_gen:
@@ -1447,8 +1397,9 @@ class SQLBaseStore(object):
         else:
             return 0
 
-    def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
-                                     desc="_simple_select_list_paginate"):
+    def _simple_select_list_paginate(
+        self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate"
+    ):
         """Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
         returning the result as a list of dicts.
@@ -1468,11 +1419,16 @@ class SQLBaseStore(object):
         return self.runInteraction(
             desc,
             self._simple_select_list_paginate_txn,
-            table, keyvalues, pagevalues, retcols
+            table,
+            keyvalues,
+            pagevalues,
+            retcols,
         )
 
     @classmethod
-    def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
+    def _simple_select_list_paginate_txn(
+        cls, txn, table, keyvalues, pagevalues, retcols
+    ):
         """Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
         returning the result as a list of dicts.
@@ -1497,22 +1453,23 @@ class SQLBaseStore(object):
                 ", ".join(retcols),
                 table,
                 " AND ".join("%s = ?" % (k,) for k in keyvalues),
-                " ? ASC LIMIT ? OFFSET ?"
+                " ? ASC LIMIT ? OFFSET ?",
             )
             txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
         else:
             sql = "SELECT %s FROM %s ORDER BY %s" % (
                 ", ".join(retcols),
                 table,
-                " ? ASC LIMIT ? OFFSET ?"
+                " ? ASC LIMIT ? OFFSET ?",
             )
             txn.execute(sql, pagevalues)
 
         return cls.cursor_to_dict(txn)
 
     @defer.inlineCallbacks
-    def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
-                               desc="get_user_list_paginate"):
+    def get_user_list_paginate(
+        self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
+    ):
         """Get a list of users from start row to a limit number of rows. This will
         return a json object with users and total number of users in users list.
 
@@ -1532,16 +1489,13 @@ class SQLBaseStore(object):
         users = yield self.runInteraction(
             desc,
             self._simple_select_list_paginate_txn,
-            table, keyvalues, pagevalues, retcols
-        )
-        count = yield self.runInteraction(
-            desc,
-            self.get_user_count_txn
+            table,
+            keyvalues,
+            pagevalues,
+            retcols,
         )
-        retval = {
-            "users": users,
-            "total": count
-        }
+        count = yield self.runInteraction(desc, self.get_user_count_txn)
+        retval = {"users": users, "total": count}
         defer.returnValue(retval)
 
     def get_user_count_txn(self, txn):
@@ -1556,8 +1510,9 @@ class SQLBaseStore(object):
         txn.execute(sql_count)
         return txn.fetchone()[0]
 
-    def _simple_search_list(self, table, term, col, retcols,
-                            desc="_simple_search_list"):
+    def _simple_search_list(
+        self, table, term, col, retcols, desc="_simple_search_list"
+    ):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -1572,9 +1527,7 @@ class SQLBaseStore(object):
         """
 
         return self.runInteraction(
-            desc,
-            self._simple_search_list_txn,
-            table, term, col, retcols
+            desc, self._simple_search_list_txn, table, term, col, retcols
         )
 
     @classmethod
@@ -1593,11 +1546,7 @@ class SQLBaseStore(object):
             defer.Deferred: resolves to list[dict[str, Any]] or None
         """
         if term:
-            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
-                ", ".join(retcols),
-                table,
-                col
-            )
+            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
             termvalues = ["%%" + term + "%%"]
             txn.execute(sql, termvalues)
         else:
@@ -1618,6 +1567,7 @@ class _RollbackButIsFineException(Exception):
     """ This exception is used to rollback a transaction without implying
     something went wrong.
     """
+
     pass