summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py461
1 files changed, 423 insertions, 38 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d9d0255d0b..5a80eef211 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import logging
 import sys
 import threading
@@ -26,9 +27,12 @@ from prometheus_client import Histogram
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.storage.engines import PostgresEngine
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import get_domain_from_id
 from synapse.util.caches.descriptors import Cache
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.stringutils import exception_to_unicode
 
 logger = logging.getLogger(__name__)
 
@@ -48,6 +52,25 @@ sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
 sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
 
 
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+    "user_ips": "user_ips_device_unique_index",
+    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+    "event_search": "event_search_event_id_idx",
+}
+
+# This is a special cache name we use to batch multiple invalidations of caches
+# based on the current state when notifying workers over replication.
+_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -83,6 +106,14 @@ class LoggingTransaction(object):
     def __iter__(self):
         return self.txn.__iter__()
 
+    def execute_batch(self, sql, args):
+        if isinstance(self.database_engine, PostgresEngine):
+            from psycopg2.extras import execute_batch
+            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+        else:
+            for val in args:
+                self.execute(sql, val)
+
     def execute(self, sql, *args):
         self._do_execute(self.txn.execute, sql, *args)
 
@@ -191,6 +222,57 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
+        # A set of tables that are not safe to use native upserts in.
+        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+        # We add the user_directory_search table to the blacklist on SQLite
+        # because the existing search table does not have an index, making it
+        # unsafe to use native upserts.
+        if isinstance(self.database_engine, Sqlite3Engine):
+            self._unsafe_to_upsert_tables.add("user_directory_search")
+
+        if self.database_engine.can_native_upsert:
+            # Check ASAP (and then later, every 1s) to see if we have finished
+            # background updates of tables that aren't safe to update.
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert
+            )
+
+    @defer.inlineCallbacks
+    def _check_safe_to_upsert(self):
+        """
+        Is it safe to use native UPSERT?
+
+        If there are background updates, we will need to wait, as they may be
+        the addition of indexes that set the UNIQUE constraint that we require.
+
+        If the background updates have not completed, wait 15 sec and check again.
+        """
+        updates = yield self._simple_select_list(
+            "background_updates",
+            keyvalues=None,
+            retcols=["update_name"],
+            desc="check_background_updates",
+        )
+        updates = [x["update_name"] for x in updates]
+
+        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+            if update_name not in updates:
+                logger.debug("Now safe to upsert in %s", table)
+                self._unsafe_to_upsert_tables.discard(table)
+
+        # If there's any updates still running, reschedule to run.
+        if updates:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert
+            )
+
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
@@ -249,32 +331,32 @@ class SQLBaseStore(object):
                 except self.database_engine.module.OperationalError as e:
                     # This can happen if the database disappears mid
                     # transaction.
-                    logger.warn(
+                    logger.warning(
                         "[TXN OPERROR] {%s} %s %d/%d",
-                        name, e, i, N
+                        name, exception_to_unicode(e), i, N
                     )
                     if i < N:
                         i += 1
                         try:
                             conn.rollback()
                         except self.database_engine.module.Error as e1:
-                            logger.warn(
+                            logger.warning(
                                 "[TXN EROLL] {%s} %s",
-                                name, e1,
+                                name, exception_to_unicode(e1),
                             )
                         continue
                     raise
                 except self.database_engine.module.DatabaseError as e:
                     if self.database_engine.is_deadlock(e):
-                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                         if i < N:
                             i += 1
                             try:
                                 conn.rollback()
                             except self.database_engine.module.Error as e1:
-                                logger.warn(
+                                logger.warning(
                                     "[TXN EROLL] {%s} %s",
-                                    name, e1,
+                                    name, exception_to_unicode(e1),
                                 )
                             continue
                     raise
@@ -493,8 +575,15 @@ class SQLBaseStore(object):
         txn.executemany(sql, vals)
 
     @defer.inlineCallbacks
-    def _simple_upsert(self, table, keyvalues, values,
-                       insertion_values={}, desc="_simple_upsert", lock=True):
+    def _simple_upsert(
+        self,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        desc="_simple_upsert",
+        lock=True
+    ):
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -515,16 +604,21 @@ class SQLBaseStore(object):
                 inserting
             lock (bool): True to lock the table when doing the upsert.
         Returns:
-            Deferred(bool): True if a new entry was created, False if an
-                existing one was updated.
+            Deferred(None or bool): Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
         """
         attempts = 0
         while True:
             try:
                 result = yield self.runInteraction(
                     desc,
-                    self._simple_upsert_txn, table, keyvalues, values, insertion_values,
-                    lock=lock
+                    self._simple_upsert_txn,
+                    table,
+                    keyvalues,
+                    values,
+                    insertion_values,
+                    lock=lock,
                 )
                 defer.returnValue(result)
             except self.database_engine.module.IntegrityError as e:
@@ -536,30 +630,111 @@ class SQLBaseStore(object):
 
                 # presumably we raced with another transaction: let's retry.
                 logger.warn(
-                    "IntegrityError when upserting into %s; retrying: %s",
-                    table, e
+                    "%s when upserting into %s; retrying: %s", e.__name__, table, e
                 )
 
-    def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
-                           lock=True):
+    def _simple_upsert_txn(
+        self,
+        txn,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        lock=True,
+    ):
+        """
+        Pick the UPSERT method which works best on the platform. Either the
+        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+        Args:
+            txn: The transaction to use.
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key tables and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            None or bool: Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        if (
+            self.database_engine.can_native_upsert
+            and table not in self._unsafe_to_upsert_tables
+        ):
+            return self._simple_upsert_txn_native_upsert(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+            )
+        else:
+            return self._simple_upsert_txn_emulated(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+                lock=lock,
+            )
+
+    def _simple_upsert_txn_emulated(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key tables and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            bool: Return True if a new entry was created, False if an existing
+            one was updated.
+        """
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.database_engine.lock_table(txn, table)
 
-        # First try to update.
-        sql = "UPDATE %s SET %s WHERE %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in values),
-            " AND ".join("%s = ?" % (k,) for k in keyvalues)
-        )
-        sqlargs = list(values.values()) + list(keyvalues.values())
+        def _getwhere(key):
+            # If the value we're passing in is None (aka NULL), we need to use
+            # IS, not =, as NULL = NULL equals NULL (False).
+            if keyvalues[key] is None:
+                return "%s IS ?" % (key,)
+            else:
+                return "%s = ?" % (key,)
 
-        txn.execute(sql, sqlargs)
-        if txn.rowcount > 0:
-            # successfully updated at least one row.
-            return False
+        if not values:
+            # If `values` is empty, then all of the values we care about are in
+            # the unique key, so there is nothing to UPDATE. We can just do a
+            # SELECT instead to see if it exists.
+            sql = "SELECT 1 FROM %s WHERE %s" % (
+                table,
+                " AND ".join(_getwhere(k) for k in keyvalues)
+            )
+            sqlargs = list(keyvalues.values())
+            txn.execute(sql, sqlargs)
+            if txn.fetchall():
+                # We have an existing record.
+                return False
+        else:
+            # First try to update.
+            sql = "UPDATE %s SET %s WHERE %s" % (
+                table,
+                ", ".join("%s = ?" % (k,) for k in values),
+                " AND ".join(_getwhere(k) for k in keyvalues)
+            )
+            sqlargs = list(values.values()) + list(keyvalues.values())
 
-        # We didn't update any rows so insert a new one
+            txn.execute(sql, sqlargs)
+            if txn.rowcount > 0:
+                # successfully updated at least one row.
+                return False
+
+        # We didn't find any existing rows, so insert a new one
         allvalues = {}
         allvalues.update(keyvalues)
         allvalues.update(values)
@@ -568,12 +743,144 @@ class SQLBaseStore(object):
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
             table,
             ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues)
+            ", ".join("?" for _ in allvalues),
         )
         txn.execute(sql, list(allvalues.values()))
         # successfully inserted
         return True
 
+    def _simple_upsert_txn_native_upsert(
+        self, txn, table, keyvalues, values, insertion_values={}
+    ):
+        """
+        Use the native UPSERT functionality in recent PostgreSQL versions.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key tables and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+        Returns:
+            None
+        """
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
+
+        sql = (
+            "INSERT INTO %s (%s) VALUES (%s) "
+            "ON CONFLICT (%s) DO UPDATE SET %s"
+        ) % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+            ", ".join(k for k in keyvalues),
+            ", ".join(k + "=EXCLUDED." + k for k in values),
+        )
+        txn.execute(sql, list(allvalues.values()))
+
+    def _simple_upsert_many_txn(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        if (
+            self.database_engine.can_native_upsert
+            and table not in self._unsafe_to_upsert_tables
+        ):
+            return self._simple_upsert_many_txn_native_upsert(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+        else:
+            return self._simple_upsert_many_txn_emulated(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+
+    def _simple_upsert_many_txn_emulated(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, but without native UPSERT support or batching.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+
+        for keyv, valv in zip(key_values, value_values):
+            _keys = {x: y for x, y in zip(key_names, keyv)}
+            _vals = {x: y for x, y in zip(value_names, valv)}
+
+            self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+    def _simple_upsert_many_txn_native_upsert(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, using batching where possible.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        allnames = []
+        allnames.extend(key_names)
+        allnames.extend(value_names)
+
+        if not value_names:
+            # No value columns, therefore make a blank list so that the
+            # following zip() works correctly.
+            latter = "NOTHING"
+            value_values = [() for x in range(len(key_values))]
+        else:
+            latter = (
+                "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
+            )
+
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+            table,
+            ", ".join(k for k in allnames),
+            ", ".join("?" for _ in allnames),
+            ", ".join(key_names),
+            latter,
+        )
+
+        args = []
+
+        for x, y in zip(key_values, value_values):
+            args.append(tuple(x) + tuple(y))
+
+        return txn.execute_batch(sql, args)
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
@@ -849,9 +1156,9 @@ class SQLBaseStore(object):
         rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
 
         if rowcount == 0:
-            raise StoreError(404, "No row found")
+            raise StoreError(404, "No row found (%s)" % (table,))
         if rowcount > 1:
-            raise StoreError(500, "More than one row matched")
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
 
     @staticmethod
     def _simple_select_one_txn(txn, table, keyvalues, retcols,
@@ -868,9 +1175,9 @@ class SQLBaseStore(object):
         if not row:
             if allow_none:
                 return None
-            raise StoreError(404, "No row found")
+            raise StoreError(404, "No row found (%s)" % (table,))
         if txn.rowcount > 1:
-            raise StoreError(500, "More than one row matched")
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
 
         return dict(zip(retcols, row))
 
@@ -902,9 +1209,9 @@ class SQLBaseStore(object):
 
         txn.execute(sql, list(keyvalues.values()))
         if txn.rowcount == 0:
-            raise StoreError(404, "No row found")
+            raise StoreError(404, "No row found (%s)" % (table,))
         if txn.rowcount > 1:
-            raise StoreError(500, "more than one row matched")
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
 
     def _simple_delete(self, table, keyvalues, desc):
         return self.runInteraction(
@@ -1005,6 +1312,84 @@ class SQLBaseStore(object):
         be invalidated.
         """
         txn.call_after(cache_func.invalidate, keys)
+        self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+
+    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+        """Special case invalidation of caches based on current state.
+
+        We special case this so that we can batch the cache invalidations into a
+        single replication poke.
+
+        Args:
+            txn
+            room_id (str): Room where state changed
+            members_changed (iterable[str]): The user_ids of members that have changed
+        """
+        txn.call_after(self._invalidate_state_caches, room_id, members_changed)
+
+        keys = itertools.chain([room_id], members_changed)
+        self._send_invalidation_to_replication(
+            txn, _CURRENT_STATE_CACHE_NAME, keys,
+        )
+
+    def _invalidate_state_caches(self, room_id, members_changed):
+        """Invalidates caches that are based on the current state, but does
+        not stream invalidations down replication.
+
+        Args:
+            room_id (str): Room where state changed
+            members_changed (iterable[str]): The user_ids of members that have
+                changed
+        """
+        for member in members_changed:
+            self._attempt_to_invalidate_cache(
+                "get_rooms_for_user_with_stream_ordering", (member,),
+            )
+
+        for host in set(get_domain_from_id(u) for u in members_changed):
+            self._attempt_to_invalidate_cache(
+                "is_host_joined", (room_id, host,),
+            )
+            self._attempt_to_invalidate_cache(
+                "was_host_joined", (room_id, host,),
+            )
+
+        self._attempt_to_invalidate_cache(
+            "get_users_in_room", (room_id,),
+        )
+        self._attempt_to_invalidate_cache(
+            "get_room_summary", (room_id,),
+        )
+        self._attempt_to_invalidate_cache(
+            "get_current_state_ids", (room_id,),
+        )
+
+    def _attempt_to_invalidate_cache(self, cache_name, key):
+        """Attempts to invalidate the cache of the given name, ignoring if the
+        cache doesn't exist. Mainly used for invalidating caches on workers,
+        where they may not have the cache.
+
+        Args:
+            cache_name (str)
+            key (tuple)
+        """
+        try:
+            getattr(self, cache_name).invalidate(key)
+        except AttributeError:
+            # We probably haven't pulled in the cache in this worker,
+            # which is fine.
+            pass
+
+    def _send_invalidation_to_replication(self, txn, cache_name, keys):
+        """Notifies replication that given cache has been invalidated.
+
+        Note that this does *not* invalidate the cache locally.
+
+        Args:
+            txn
+            cache_name (str)
+            keys (iterable[str])
+        """
 
         if isinstance(self.database_engine, PostgresEngine):
             # get_next() returns a context manager which is designed to wrap
@@ -1022,7 +1407,7 @@ class SQLBaseStore(object):
                 table="cache_invalidation_stream",
                 values={
                     "stream_id": stream_id,
-                    "cache_func": cache_func.__name__,
+                    "cache_func": cache_name,
                     "keys": list(keys),
                     "invalidation_ts": self.clock.time_msec(),
                 }