diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d9d0255d0b..5a80eef211 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
import sys
import threading
@@ -26,9 +27,12 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.storage.engines import PostgresEngine
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.stringutils import exception_to_unicode
logger = logging.getLogger(__name__)
@@ -48,6 +52,25 @@ sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+ "user_ips": "user_ips_device_unique_index",
+ "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+ "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+ "event_search": "event_search_event_id_idx",
+}
+
+# This is a special cache name we use to batch multiple invalidations of caches
+# based on the current state when notifying workers over replication.
+_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -83,6 +106,14 @@ class LoggingTransaction(object):
def __iter__(self):
return self.txn.__iter__()
+ def execute_batch(self, sql, args):
+ if isinstance(self.database_engine, PostgresEngine):
+ from psycopg2.extras import execute_batch
+ self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+ else:
+ for val in args:
+ self.execute(sql, val)
+
def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args)
@@ -191,6 +222,57 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
+ # A set of tables that are not safe to use native upserts in.
+ self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+ # We add the user_directory_search table to the blacklist on SQLite
+ # because the existing search table does not have an index, making it
+ # unsafe to use native upserts.
+ if isinstance(self.database_engine, Sqlite3Engine):
+ self._unsafe_to_upsert_tables.add("user_directory_search")
+
+ if self.database_engine.can_native_upsert:
+ # Check ASAP (and then later, every 1s) to see if we have finished
+ # background updates of tables that aren't safe to update.
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
+ @defer.inlineCallbacks
+ def _check_safe_to_upsert(self):
+ """
+ Is it safe to use native UPSERT?
+
+ If there are background updates, we will need to wait, as they may be
+ the addition of indexes that set the UNIQUE constraint that we require.
+
+ If the background updates have not completed, wait 15 sec and check again.
+ """
+ updates = yield self._simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ )
+ updates = [x["update_name"] for x in updates]
+
+ for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+ if update_name not in updates:
+ logger.debug("Now safe to upsert in %s", table)
+ self._unsafe_to_upsert_tables.discard(table)
+
+ # If there's any updates still running, reschedule to run.
+ if updates:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -249,32 +331,32 @@ class SQLBaseStore(object):
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
- logger.warn(
+ logger.warning(
"[TXN OPERROR] {%s} %s %d/%d",
- name, e, i, N
+ name, exception_to_unicode(e), i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
- logger.warn(
+ logger.warning(
"[TXN EROLL] {%s} %s",
- name, e1,
+ name, exception_to_unicode(e1),
)
continue
raise
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
- logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
- logger.warn(
+ logger.warning(
"[TXN EROLL] {%s} %s",
- name, e1,
+ name, exception_to_unicode(e1),
)
continue
raise
@@ -493,8 +575,15 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
- def _simple_upsert(self, table, keyvalues, values,
- insertion_values={}, desc="_simple_upsert", lock=True):
+ def _simple_upsert(
+ self,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ desc="_simple_upsert",
+ lock=True
+ ):
"""
`lock` should generally be set to True (the default), but can be set
@@ -515,16 +604,21 @@ class SQLBaseStore(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
- Deferred(bool): True if a new entry was created, False if an
- existing one was updated.
+ Deferred(None or bool): Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
desc,
- self._simple_upsert_txn, table, keyvalues, values, insertion_values,
- lock=lock
+ self._simple_upsert_txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values,
+ lock=lock,
)
defer.returnValue(result)
except self.database_engine.module.IntegrityError as e:
@@ -536,30 +630,111 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
- "IntegrityError when upserting into %s; retrying: %s",
- table, e
+ "%s when upserting into %s; retrying: %s", e.__name__, table, e
)
- def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
- lock=True):
+ def _simple_upsert_txn(
+ self,
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ lock=True,
+ ):
+ """
+ Pick the UPSERT method which works best on the platform. Either the
+ native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+ Args:
+ txn: The transaction to use.
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ None or bool: Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ if (
+ self.database_engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ ):
+ return self._simple_upsert_txn_native_upsert(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ )
+ else:
+ return self._simple_upsert_txn_emulated(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ lock=lock,
+ )
+
+ def _simple_upsert_txn_emulated(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ bool: Return True if a new entry was created, False if an existing
+ one was updated.
+ """
# We need to lock the table :(, unless we're *really* careful
if lock:
self.database_engine.lock_table(txn, table)
- # First try to update.
- sql = "UPDATE %s SET %s WHERE %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in values),
- " AND ".join("%s = ?" % (k,) for k in keyvalues)
- )
- sqlargs = list(values.values()) + list(keyvalues.values())
+ def _getwhere(key):
+ # If the value we're passing in is None (aka NULL), we need to use
+ # IS, not =, as NULL = NULL equals NULL (False).
+ if keyvalues[key] is None:
+ return "%s IS ?" % (key,)
+ else:
+ return "%s = ?" % (key,)
- txn.execute(sql, sqlargs)
- if txn.rowcount > 0:
- # successfully updated at least one row.
- return False
+ if not values:
+ # If `values` is empty, then all of the values we care about are in
+ # the unique key, so there is nothing to UPDATE. We can just do a
+ # SELECT instead to see if it exists.
+ sql = "SELECT 1 FROM %s WHERE %s" % (
+ table,
+ " AND ".join(_getwhere(k) for k in keyvalues)
+ )
+ sqlargs = list(keyvalues.values())
+ txn.execute(sql, sqlargs)
+ if txn.fetchall():
+ # We have an existing record.
+ return False
+ else:
+ # First try to update.
+ sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in values),
+ " AND ".join(_getwhere(k) for k in keyvalues)
+ )
+ sqlargs = list(values.values()) + list(keyvalues.values())
- # We didn't update any rows so insert a new one
+ txn.execute(sql, sqlargs)
+ if txn.rowcount > 0:
+ # successfully updated at least one row.
+ return False
+
+ # We didn't find any existing rows, so insert a new one
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
@@ -568,12 +743,144 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues)
+ ", ".join("?" for _ in allvalues),
)
txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
+ def _simple_upsert_txn_native_upsert(
+ self, txn, table, keyvalues, values, insertion_values={}
+ ):
+ """
+ Use the native UPSERT functionality in recent PostgreSQL versions.
+
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ Returns:
+ None
+ """
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
+
+ sql = (
+ "INSERT INTO %s (%s) VALUES (%s) "
+ "ON CONFLICT (%s) DO UPDATE SET %s"
+ ) % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ ", ".join(k for k in keyvalues),
+ ", ".join(k + "=EXCLUDED." + k for k in values),
+ )
+ txn.execute(sql, list(allvalues.values()))
+
+ def _simple_upsert_many_txn(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ if (
+ self.database_engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ ):
+ return self._simple_upsert_many_txn_native_upsert(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+ else:
+ return self._simple_upsert_many_txn_emulated(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+
+ def _simple_upsert_many_txn_emulated(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, but without native UPSERT support or batching.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ # No value columns, therefore make a blank list so that the following
+ # zip() works correctly.
+ if not value_names:
+ value_values = [() for x in range(len(key_values))]
+
+ for keyv, valv in zip(key_values, value_values):
+ _keys = {x: y for x, y in zip(key_names, keyv)}
+ _vals = {x: y for x, y in zip(value_names, valv)}
+
+ self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+ def _simple_upsert_many_txn_native_upsert(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, using batching where possible.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ allnames = []
+ allnames.extend(key_names)
+ allnames.extend(value_names)
+
+ if not value_names:
+ # No value columns, therefore make a blank list so that the
+ # following zip() works correctly.
+ latter = "NOTHING"
+ value_values = [() for x in range(len(key_values))]
+ else:
+ latter = (
+ "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
+ )
+
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+ table,
+ ", ".join(k for k in allnames),
+ ", ".join("?" for _ in allnames),
+ ", ".join(key_names),
+ latter,
+ )
+
+ args = []
+
+ for x, y in zip(key_values, value_values):
+ args.append(tuple(x) + tuple(y))
+
+ return txn.execute_batch(sql, args)
+
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
@@ -849,9 +1156,9 @@ class SQLBaseStore(object):
rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
- raise StoreError(404, "No row found")
+ raise StoreError(404, "No row found (%s)" % (table,))
if rowcount > 1:
- raise StoreError(500, "More than one row matched")
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod
def _simple_select_one_txn(txn, table, keyvalues, retcols,
@@ -868,9 +1175,9 @@ class SQLBaseStore(object):
if not row:
if allow_none:
return None
- raise StoreError(404, "No row found")
+ raise StoreError(404, "No row found (%s)" % (table,))
if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched")
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
return dict(zip(retcols, row))
@@ -902,9 +1209,9 @@ class SQLBaseStore(object):
txn.execute(sql, list(keyvalues.values()))
if txn.rowcount == 0:
- raise StoreError(404, "No row found")
+ raise StoreError(404, "No row found (%s)" % (table,))
if txn.rowcount > 1:
- raise StoreError(500, "more than one row matched")
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
def _simple_delete(self, table, keyvalues, desc):
return self.runInteraction(
@@ -1005,6 +1312,84 @@ class SQLBaseStore(object):
be invalidated.
"""
txn.call_after(cache_func.invalidate, keys)
+ self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+
+ def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ """Special case invalidation of caches based on current state.
+
+ We special case this so that we can batch the cache invalidations into a
+ single replication poke.
+
+ Args:
+ txn
+ room_id (str): Room where state changed
+ members_changed (iterable[str]): The user_ids of members that have changed
+ """
+ txn.call_after(self._invalidate_state_caches, room_id, members_changed)
+
+ keys = itertools.chain([room_id], members_changed)
+ self._send_invalidation_to_replication(
+ txn, _CURRENT_STATE_CACHE_NAME, keys,
+ )
+
+ def _invalidate_state_caches(self, room_id, members_changed):
+ """Invalidates caches that are based on the current state, but does
+ not stream invalidations down replication.
+
+ Args:
+ room_id (str): Room where state changed
+ members_changed (iterable[str]): The user_ids of members that have
+ changed
+ """
+ for member in members_changed:
+ self._attempt_to_invalidate_cache(
+ "get_rooms_for_user_with_stream_ordering", (member,),
+ )
+
+ for host in set(get_domain_from_id(u) for u in members_changed):
+ self._attempt_to_invalidate_cache(
+ "is_host_joined", (room_id, host,),
+ )
+ self._attempt_to_invalidate_cache(
+ "was_host_joined", (room_id, host,),
+ )
+
+ self._attempt_to_invalidate_cache(
+ "get_users_in_room", (room_id,),
+ )
+ self._attempt_to_invalidate_cache(
+ "get_room_summary", (room_id,),
+ )
+ self._attempt_to_invalidate_cache(
+ "get_current_state_ids", (room_id,),
+ )
+
+ def _attempt_to_invalidate_cache(self, cache_name, key):
+ """Attempts to invalidate the cache of the given name, ignoring if the
+ cache doesn't exist. Mainly used for invalidating caches on workers,
+ where they may not have the cache.
+
+ Args:
+ cache_name (str)
+ key (tuple)
+ """
+ try:
+ getattr(self, cache_name).invalidate(key)
+ except AttributeError:
+ # We probably haven't pulled in the cache in this worker,
+ # which is fine.
+ pass
+
+ def _send_invalidation_to_replication(self, txn, cache_name, keys):
+ """Notifies replication that given cache has been invalidated.
+
+ Note that this does *not* invalidate the cache locally.
+
+ Args:
+ txn
+ cache_name (str)
+ keys (iterable[str])
+ """
if isinstance(self.database_engine, PostgresEngine):
# get_next() returns a context manager which is designed to wrap
@@ -1022,7 +1407,7 @@ class SQLBaseStore(object):
table="cache_invalidation_stream",
values={
"stream_id": stream_id,
- "cache_func": cache_func.__name__,
+ "cache_func": cache_name,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}
|