diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 24329879e5..42cd3c83ad 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -317,7 +317,7 @@ class DataStore(RoomMemberStore, RoomStore,
thirty_days_ago_in_secs))
for row in txn:
- if row[0] is 'unknown':
+ if row[0] == 'unknown':
pass
results[row[0]] = row[1]
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 865b5e915a..e124161845 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -26,7 +26,8 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.storage.engines import PostgresEngine
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.stringutils import exception_to_unicode
@@ -49,6 +50,21 @@ sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+ "user_ips": "user_ips_device_unique_index",
+ "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+ "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+ "event_search": "event_search_event_id_idx",
+}
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -192,6 +208,57 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
+ # A set of tables that are not safe to use native upserts in.
+ self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+ # We add the user_directory_search table to the blacklist on SQLite
+ # because the existing search table does not have an index, making it
+ # unsafe to use native upserts.
+ if isinstance(self.database_engine, Sqlite3Engine):
+ self._unsafe_to_upsert_tables.add("user_directory_search")
+
+ if self.database_engine.can_native_upsert:
+ # Check ASAP (and then later, every 1s) to see if we have finished
+ # background updates of tables that aren't safe to update.
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
+ @defer.inlineCallbacks
+ def _check_safe_to_upsert(self):
+ """
+ Is it safe to use native UPSERT?
+
+ If there are background updates, we will need to wait, as they may be
+ the addition of indexes that set the UNIQUE constraint that we require.
+
+ If the background updates have not completed, wait 15 sec and check again.
+ """
+ updates = yield self._simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ )
+ updates = [x["update_name"] for x in updates]
+
+ for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+ if update_name not in updates:
+ logger.debug("Now safe to upsert in %s", table)
+ self._unsafe_to_upsert_tables.discard(table)
+
+ # If there's any updates still running, reschedule to run.
+ if updates:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -494,8 +561,15 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
- def _simple_upsert(self, table, keyvalues, values,
- insertion_values={}, desc="_simple_upsert", lock=True):
+ def _simple_upsert(
+ self,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ desc="_simple_upsert",
+ lock=True
+ ):
"""
`lock` should generally be set to True (the default), but can be set
@@ -516,16 +590,21 @@ class SQLBaseStore(object):
inserting
lock (bool): True to lock the table when doing the upsert.
Returns:
- Deferred(bool): True if a new entry was created, False if an
- existing one was updated.
+ Deferred(None or bool): Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
"""
attempts = 0
while True:
try:
result = yield self.runInteraction(
desc,
- self._simple_upsert_txn, table, keyvalues, values, insertion_values,
- lock=lock
+ self._simple_upsert_txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values,
+ lock=lock,
)
defer.returnValue(result)
except self.database_engine.module.IntegrityError as e:
@@ -537,12 +616,71 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
- "IntegrityError when upserting into %s; retrying: %s",
- table, e
+ "%s when upserting into %s; retrying: %s", e.__name__, table, e
)
- def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
- lock=True):
+ def _simple_upsert_txn(
+ self,
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ lock=True,
+ ):
+ """
+ Pick the UPSERT method which works best on the platform. Either the
+ native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+ Args:
+ txn: The transaction to use.
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ None or bool: Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ if (
+ self.database_engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ ):
+ return self._simple_upsert_txn_native_upsert(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ )
+ else:
+ return self._simple_upsert_txn_emulated(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ lock=lock,
+ )
+
+ def _simple_upsert_txn_emulated(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ bool: Return True if a new entry was created, False if an existing
+ one was updated.
+ """
# We need to lock the table :(, unless we're *really* careful
if lock:
self.database_engine.lock_table(txn, table)
@@ -577,12 +715,44 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues)
+ ", ".join("?" for _ in allvalues),
)
txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
+ def _simple_upsert_txn_native_upsert(
+ self, txn, table, keyvalues, values, insertion_values={}
+ ):
+ """
+ Use the native UPSERT functionality in recent PostgreSQL versions.
+
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ Returns:
+ None
+ """
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
+
+ sql = (
+ "INSERT INTO %s (%s) VALUES (%s) "
+ "ON CONFLICT (%s) DO UPDATE SET %s"
+ ) % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ ", ".join(k for k in keyvalues),
+ ", ".join(k + "=EXCLUDED." + k for k in values),
+ )
+ txn.execute(sql, list(allvalues.values()))
+
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 5fe1ca2de7..60cdc884e6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -240,7 +240,7 @@ class BackgroundUpdateStore(SQLBaseStore):
* An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated.
- The hander is responsible for updating the progress of the update.
+ The handler is responsible for updating the progress of the update.
Args:
update_name(str): The name of the update that this code handles.
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index b228a20ac2..9c21362226 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -66,6 +66,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
self.register_background_update_handler(
+ "user_ips_analyze",
+ self._analyze_user_ip,
+ )
+
+ self.register_background_update_handler(
"user_ips_remove_dupes",
self._remove_user_ip_dupes,
)
@@ -109,6 +114,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
defer.returnValue(1)
@defer.inlineCallbacks
+ def _analyze_user_ip(self, progress, batch_size):
+ # Background update to analyze user_ips table before we run the
+ # deduplication background update. The table may not have been analyzed
+ # for ages due to the table locks.
+ #
+ # This will lock out the naive upserts to user_ips while it happens, but
+ # the analyze should be quick (28GB table takes ~10s)
+ def user_ips_analyze(txn):
+ txn.execute("ANALYZE user_ips")
+
+ yield self.runInteraction(
+ "user_ips_analyze", user_ips_analyze
+ )
+
+ yield self._end_background_update("user_ips_analyze")
+
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
# This works function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
@@ -167,12 +191,16 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
clause = "? <= last_seen AND last_seen < ?"
args = (begin_last_seen, end_last_seen)
+ # (Note: The DISTINCT in the inner query is important to ensure that
+ # the COUNT(*) is accurate, otherwise double counting may happen due
+ # to the join effectively being a cross product)
txn.execute(
"""
SELECT user_id, access_token, ip,
- MAX(device_id), MAX(user_agent), MAX(last_seen)
+ MAX(device_id), MAX(user_agent), MAX(last_seen),
+ COUNT(*)
FROM (
- SELECT user_id, access_token, ip
+ SELECT DISTINCT user_id, access_token, ip
FROM user_ips
WHERE {}
) c
@@ -186,7 +214,60 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# We've got some duplicates
for i in res:
- user_id, access_token, ip, device_id, user_agent, last_seen = i
+ user_id, access_token, ip, device_id, user_agent, last_seen, count = i
+
+ # We want to delete the duplicates so we end up with only a
+ # single row.
+ #
+ # The naive way of doing this would be just to delete all rows
+ # and reinsert a constructed row. However, if there are a lot of
+ # duplicate rows this can cause the table to grow a lot, which
+ # can be problematic in two ways:
+ # 1. If user_ips is already large then this can cause the
+ # table to rapidly grow, potentially filling the disk.
+ # 2. Reinserting a lot of rows can confuse the table
+ # statistics for postgres, causing it to not use the
+ # correct indices for the query above, resulting in a full
+ # table scan. This is incredibly slow for large tables and
+ # can kill database performance. (This seems to mainly
+ # happen for the last query where the clause is simply `? <
+ # last_seen`)
+ #
+ # So instead we want to delete all but *one* of the duplicate
+ # rows. That is hard to do reliably, so we cheat and do a two
+ # step process:
+ # 1. Delete all rows with a last_seen strictly less than the
+ # max last_seen. This hopefully results in deleting all but
+ # one row the majority of the time, but there may be
+ # duplicate last_seen
+ # 2. If multiple rows remain, we fall back to the naive method
+ # and simply delete all rows and reinsert.
+ #
+ # Note that this relies on no new duplicate rows being inserted,
+ # but if that is happening then this entire process is futile
+ # anyway.
+
+ # Do step 1:
+
+ txn.execute(
+ """
+ DELETE FROM user_ips
+ WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?
+ """,
+ (user_id, access_token, ip, last_seen)
+ )
+ if txn.rowcount == count - 1:
+ # We deleted all but one of the duplicate rows, i.e. there
+ # is exactly one remaining and so there is nothing left to
+ # do.
+ continue
+ elif txn.rowcount >= count:
+ raise Exception(
+ "We deleted more duplicate rows from 'user_ips' than expected",
+ )
+
+ # The previous step didn't delete enough rows, so we fallback to
+ # step 2:
# Drop all the duplicates
txn.execute(
@@ -257,7 +338,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
def _update_client_ips_batch_txn(self, txn, to_update):
- self.database_engine.lock_table(txn, "user_ips")
+ if "user_ips" in self._unsafe_to_upsert_tables or (
+ not self.database_engine.can_native_upsert
+ ):
+ self.database_engine.lock_table(txn, "user_ips")
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index 45cebe61d1..9a3aec759e 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -298,6 +298,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
+ def update_e2e_room_keys_version(self, user_id, version, info):
+ """Update a given backup version
+
+ Args:
+ user_id(str): the user whose backup version we're updating
+ version(str): the version ID of the backup version we're updating
+ info(dict): the new backup version info to store
+ """
+
+ return self._simple_update(
+ table="e2e_room_keys_versions",
+ keyvalues={
+ "user_id": user_id,
+ "version": version,
+ },
+ updatevalues={
+ "auth_data": json.dumps(info["auth_data"]),
+ },
+ desc="update_e2e_room_keys_version"
+ )
+
def delete_e2e_room_keys_version(self, user_id, version=None):
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index e2f9de8451..ff5ef97ca8 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -18,7 +18,7 @@ import platform
from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
-from .sqlite3 import Sqlite3Engine
+from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 42225f8a2a..4004427c7b 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -38,6 +38,13 @@ class PostgresEngine(object):
return sql.replace("?", "%s")
def on_new_connection(self, db_conn):
+
+ # Get the version of PostgreSQL that we're using. As per the psycopg2
+ # docs: The number is formed by converting the major, minor, and
+ # revision numbers into two-decimal-digit numbers and appending them
+ # together. For example, version 8.1.5 will be returned as 80105
+ self._version = db_conn.server_version
+
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -54,6 +61,13 @@ class PostgresEngine(object):
cursor.close()
+ @property
+ def can_native_upsert(self):
+ """
+ Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+ """
+ return self._version >= 90500
+
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite.py
index 19949fc474..059ab81055 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite.py
@@ -30,6 +30,14 @@ class Sqlite3Engine(object):
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
+ @property
+ def can_native_upsert(self):
+ """
+ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
+ more work we haven't done yet to tell what was inserted vs updated.
+ """
+ return self.module.sqlite_version_info >= (3, 24, 0)
+
def check_database(self, txn):
pass
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index d3b9dea1d6..38809ed0fc 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -125,6 +125,29 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return dict(txn)
+ @defer.inlineCallbacks
+ def get_max_depth_of(self, event_ids):
+ """Returns the max depth of a set of event IDs
+
+ Args:
+ event_ids (list[str])
+
+ Returns
+ Deferred[int]
+ """
+ rows = yield self._simple_select_many_batch(
+ table="events",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("depth",),
+ desc="get_max_depth_of",
+ )
+
+ if not rows:
+ defer.returnValue(0)
+ else:
+ defer.returnValue(max(row["depth"] for row in rows))
+
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
txn,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 79e0276de6..81b250480d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -904,106 +904,106 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
- to_delete, to_insert = current_state_tuple
-
- # First we add entries to the current_state_delta_stream. We
- # do this before updating the current_state_events table so
- # that we can use it to calculate the `prev_event_id`. (This
- # allows us to not have to pull out the existing state
- # unnecessarily).
- sql = """
- INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, ?, ?, ?, ?, (
- SELECT event_id FROM current_state_events
- WHERE room_id = ? AND type = ? AND state_key = ?
- )
- """
- txn.executemany(sql, (
- (
- max_stream_order, room_id, etype, state_key, None,
- room_id, etype, state_key,
- )
- for etype, state_key in to_delete
- # We sanity check that we're deleting rather than updating
- if (etype, state_key) not in to_insert
- ))
- txn.executemany(sql, (
- (
- max_stream_order, room_id, etype, state_key, ev_id,
- room_id, etype, state_key,
- )
- for (etype, state_key), ev_id in iteritems(to_insert)
- ))
+ to_delete, to_insert = current_state_tuple
- # Now we actually update the current_state_events table
-
- txn.executemany(
- "DELETE FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?",
- (
- (room_id, etype, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, (
+ SELECT event_id FROM current_state_events
+ WHERE room_id = ? AND type = ? AND state_key = ?
)
-
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
- for key, ev_id in iteritems(to_insert)
- ],
+ """
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, None,
+ room_id, etype, state_key,
)
-
- txn.call_after(
- self._curr_state_delta_stream_cache.entity_has_changed,
- room_id, max_stream_order,
+ for etype, state_key in to_delete
+ # We sanity check that we're deleting rather than updating
+ if (etype, state_key) not in to_insert
+ ))
+ txn.executemany(sql, (
+ (
+ max_stream_order, room_id, etype, state_key, ev_id,
+ room_id, etype, state_key,
)
+ for (etype, state_key), ev_id in iteritems(to_insert)
+ ))
- # Invalidate the various caches
-
- # Figure out the changes of membership to invalidate the
- # `get_rooms_for_user` cache.
- # We find out which membership events we may have deleted
- # and which we have added, then we invlidate the caches for all
- # those users.
- members_changed = set(
- state_key
- for ev_type, state_key in itertools.chain(to_delete, to_insert)
- if ev_type == EventTypes.Member
- )
+ # Now we actually update the current_state_events table
- for member in members_changed:
- self._invalidate_cache_and_stream(
- txn, self.get_rooms_for_user_with_stream_ordering, (member,)
- )
+ txn.executemany(
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
- for host in set(get_domain_from_id(u) for u in members_changed):
- self._invalidate_cache_and_stream(
- txn, self.is_host_joined, (room_id, host)
- )
- self._invalidate_cache_and_stream(
- txn, self.was_host_joined, (room_id, host)
- )
+ self._simple_insert_many_txn(
+ txn,
+ table="current_state_events",
+ values=[
+ {
+ "event_id": ev_id,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ }
+ for key, ev_id in iteritems(to_insert)
+ ],
+ )
+
+ txn.call_after(
+ self._curr_state_delta_stream_cache.entity_has_changed,
+ room_id, max_stream_order,
+ )
+
+ # Invalidate the various caches
+
+ # Figure out the changes of membership to invalidate the
+ # `get_rooms_for_user` cache.
+ # We find out which membership events we may have deleted
+ # and which we have added, then we invlidate the caches for all
+ # those users.
+ members_changed = set(
+ state_key
+ for ev_type, state_key in itertools.chain(to_delete, to_insert)
+ if ev_type == EventTypes.Member
+ )
+ for member in members_changed:
self._invalidate_cache_and_stream(
- txn, self.get_users_in_room, (room_id,)
+ txn, self.get_rooms_for_user_with_stream_ordering, (member,)
)
+ for host in set(get_domain_from_id(u) for u in members_changed):
self._invalidate_cache_and_stream(
- txn, self.get_room_summary, (room_id,)
+ txn, self.is_host_joined, (room_id, host)
)
-
self._invalidate_cache_and_stream(
- txn, self.get_current_state_ids, (room_id,)
+ txn, self.was_host_joined, (room_id, host)
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_users_in_room, (room_id,)
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_room_summary, (room_id,)
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_current_state_ids, (room_id,)
+ )
+
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order):
for room_id, new_extrem in iteritems(new_forward_extremities):
@@ -1268,6 +1268,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
event.internal_metadata.get_dict()
),
"json": encode_json(event_dict(event)),
+ "format_version": event.format_version,
}
for event, _ in events_and_contexts
],
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index a8326f5296..1716be529a 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -21,13 +21,14 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.api.constants import EventFormatVersions, EventTypes
from synapse.api.errors import NotFoundError
+from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
# these are only included to make the type annotations work
-from synapse.events import EventBase # noqa: F401
-from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import get_domain_from_id
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@@ -160,9 +161,14 @@ class EventsWorkerStore(SQLBaseStore):
log_ctx = LoggingContext.current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
+ # Note that _enqueue_events is also responsible for turning db rows
+ # into FrozenEvents (via _get_event_from_row), which involves seeing if
+ # the events have been redacted, and if so pulling the redaction event out
+ # of the database to check it.
+ #
+ # _enqueue_events is a bit of a rubbish name but naming is hard.
missing_events = yield self._enqueue_events(
missing_events_ids,
- check_redacted=check_redacted,
allow_rejected=allow_rejected,
)
@@ -174,6 +180,50 @@ class EventsWorkerStore(SQLBaseStore):
if not entry:
continue
+ # Starting in room version v3, some redactions need to be rechecked if we
+ # didn't have the redacted event at the time, so we recheck on read
+ # instead.
+ if not allow_rejected and entry.event.type == EventTypes.Redaction:
+ if entry.event.internal_metadata.need_to_check_redaction():
+ # XXX: we need to avoid calling get_event here.
+ #
+ # The problem is that we end up at this point when an event
+ # which has been redacted is pulled out of the database by
+ # _enqueue_events, because _enqueue_events needs to check the
+ # redaction before it can cache the redacted event. So obviously,
+ # calling get_event to get the redacted event out of the database
+ # gives us an infinite loop.
+ #
+ # For now (quick hack to fix during 0.99 release cycle), we just
+ # go and fetch the relevant row from the db, but it would be nice
+ # to think about how we can cache this rather than hit the db
+ # every time we access a redaction event.
+ #
+ # One thought on how to do this:
+ # 1. split _get_events up so that it is divided into (a) get the
+ # rawish event from the db/cache, (b) do the redaction/rejection
+ # filtering
+ # 2. have _get_event_from_row just call the first half of that
+
+ orig_sender = yield self._simple_select_one_onecol(
+ table="events",
+ keyvalues={"event_id": entry.event.redacts},
+ retcol="sender",
+ allow_none=True,
+ )
+
+ expected_domain = get_domain_from_id(entry.event.sender)
+ if orig_sender and get_domain_from_id(orig_sender) == expected_domain:
+ # This redaction event is allowed. Mark as not needing a
+ # recheck.
+ entry.event.internal_metadata.recheck_redaction = False
+ else:
+ # We don't have the event that is being redacted, so we
+ # assume that the event isn't authorized for now. (If we
+ # later receive the event, then we will always redact
+ # it anyway, since we have this redaction)
+ continue
+
if allow_rejected or not entry.event.rejected_reason:
if check_redacted and entry.redacted_event:
event = entry.redacted_event
@@ -197,7 +247,7 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue(events)
def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
+ self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
@@ -310,7 +360,7 @@ class EventsWorkerStore(SQLBaseStore):
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
- def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+ def _enqueue_events(self, events, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -353,6 +403,7 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_from_row,
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
+ format_version=row["format_version"],
)
for row in rows
],
@@ -377,6 +428,7 @@ class EventsWorkerStore(SQLBaseStore):
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
+ " e.format_version, "
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
@@ -392,7 +444,7 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
- rejected_reason=None):
+ format_version, rejected_reason=None):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
@@ -405,8 +457,13 @@ class EventsWorkerStore(SQLBaseStore):
desc="_get_event_from_row_rejected_reason",
)
- original_ev = FrozenEvent(
- d,
+ if format_version is None:
+ # This means that we stored the event before we had the concept
+ # of a event format version, so it must be a V1 event.
+ format_version = EventFormatVersions.V1
+
+ original_ev = event_type_from_format_version(format_version)(
+ event_dict=d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
@@ -436,6 +493,19 @@ class EventsWorkerStore(SQLBaseStore):
# will serialise this field correctly
redacted_event.unsigned["redacted_because"] = because
+ # Starting in room version v3, some redactions need to be
+ # rechecked if we didn't have the redacted event at the
+ # time, so we recheck on read instead.
+ if because.internal_metadata.need_to_check_redaction():
+ expected_domain = get_domain_from_id(original_ev.sender)
+ if get_domain_from_id(because.sender) == expected_domain:
+ # This redaction event is allowed. Mark as not needing a
+ # recheck.
+ because.internal_metadata.recheck_redaction = False
+ else:
+ # Senders don't match, so the event isn't actually redacted
+ redacted_event = None
+
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index d6fc8edd4c..9e7e09b8c1 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -197,15 +197,21 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
- is_insert = yield self.runInteraction(
+ yield self.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
user_id
)
- if is_insert:
- self.user_last_seen_monthly_active.invalidate((user_id,))
+ user_in_mau = self.user_last_seen_monthly_active.cache.get(
+ (user_id,),
+ None,
+ update_metrics=False
+ )
+ if user_in_mau is None:
self.get_monthly_active_count.invalidate(())
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+
def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 2743b52bad..134297e284 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
- newly_inserted = yield self._simple_upsert(
+ yield self._simple_upsert(
table="pushers",
keyvalues={
"app_id": app_id,
@@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
- if newly_inserted:
+ user_has_pusher = self.get_if_user_has_pusher.cache.get(
+ (user_id,), None, update_metrics=False
+ )
+
+ if user_has_pusher is not True:
+ # invalidate, since we the user might not have had a pusher before
yield self.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0707f9a86a..592c1bcd33 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -588,12 +588,12 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
# We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened.
- # The only current event that can also be an outlier is if its an
- # invite that has come in across federation.
+ # i.e., its something that has just happened. If the event is an
+ # outlier it is only current if its an "out of band membership",
+ # like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_invite_from_remote()
+ or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/schema/delta/53/event_format_version.sql
new file mode 100644
index 0000000000..1d977c2834
--- /dev/null
+++ b/synapse/storage/schema/delta/53/event_format_version.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_json ADD COLUMN format_version INTEGER;
diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/schema/delta/53/user_ips_index.sql
index 4ca346c111..b812c5794f 100644
--- a/synapse/storage/schema/delta/53/user_ips_index.sql
+++ b/synapse/storage/schema/delta/53/user_ips_index.sql
@@ -13,9 +13,13 @@
* limitations under the License.
*/
--- delete duplicates
+ -- analyze user_ips, to help ensure the correct indices are used
INSERT INTO background_updates (update_name, progress_json) VALUES
- ('user_ips_remove_dupes', '{}');
+ ('user_ips_analyze', '{}');
+
+-- delete duplicates
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('user_ips_remove_dupes', '{}', 'user_ips_analyze');
-- add a new unique index to user_ips table
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
@@ -23,4 +27,4 @@ INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
-- drop the old original index
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index');
\ No newline at end of file
+ ('user_ips_drop_nonunique_index', '{}', 'user_ips_device_unique_index');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a134e9b3e8..d14a7b2538 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -428,14 +428,54 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
# for now we do this by looking at the create event. We may want to cache this
# more intelligently in future.
+
+ # Retrieve the room's create event
+ create_event = yield self.get_create_event_for_room(room_id)
+ defer.returnValue(create_event.content.get("room_version", "1"))
+
+ @defer.inlineCallbacks
+ def get_room_predecessor(self, room_id):
+ """Get the predecessor room of an upgraded room if one exists.
+ Otherwise return None.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[unicode|None]: predecessor room id
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
+ # Retrieve the room's create event
+ create_event = yield self.get_create_event_for_room(room_id)
+
+ # Return predecessor if present
+ defer.returnValue(create_event.content.get("predecessor", None))
+
+ @defer.inlineCallbacks
+ def get_create_event_for_room(self, room_id):
+ """Get the create state event for a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[EventBase]: The room creation event.
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
state_ids = yield self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
+ # If we can't find the create event, assume we've hit a dead end
if not create_id:
raise NotFoundError("Unknown room %s" % (room_id))
+ # Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
- defer.returnValue(create_event.content.get("room_version", "1"))
+ defer.returnValue(create_event)
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index a8781b0e5d..fea866c043 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -31,12 +32,19 @@ logger = logging.getLogger(__name__)
class UserDirectoryStore(SQLBaseStore):
- @cachedInlineCallbacks(cache_context=True)
- def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+ @defer.inlineCallbacks
+ def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable
"""
- current_state_ids = yield self.get_current_state_ids(
- room_id, on_invalidate=cache_context.invalidate
+
+ # Create a state filter that only queries join and history state event
+ types_to_filter = (
+ (EventTypes.JoinRules, ""),
+ (EventTypes.RoomHistoryVisibility, ""),
+ )
+
+ current_state_ids = yield self.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types(types_to_filter)
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
@@ -66,14 +74,8 @@ class UserDirectoryStore(SQLBaseStore):
"""
yield self._simple_insert_many(
table="users_in_public_rooms",
- values=[
- {
- "user_id": user_id,
- "room_id": room_id,
- }
- for user_id in user_ids
- ],
- desc="add_users_to_public_room"
+ values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids],
+ desc="add_users_to_public_room",
)
for user_id in user_ids:
self.get_user_in_public_room.invalidate((user_id,))
@@ -99,7 +101,9 @@ class UserDirectoryStore(SQLBaseStore):
"""
args = (
(
- user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+ user_id,
+ get_localpart_from_id(user_id),
+ get_domain_from_id(user_id),
profile.display_name,
)
for user_id, profile in iteritems(users_with_profile)
@@ -112,7 +116,7 @@ class UserDirectoryStore(SQLBaseStore):
args = (
(
user_id,
- "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+ "%s %s" % (user_id, p.display_name) if p.display_name else user_id,
)
for user_id, p in iteritems(users_with_profile)
)
@@ -133,12 +137,10 @@ class UserDirectoryStore(SQLBaseStore):
"avatar_url": profile.avatar_url,
}
for user_id, profile in iteritems(users_with_profile)
- ]
+ ],
)
for user_id in users_with_profile:
- txn.call_after(
- self.get_user_in_directory.invalidate, (user_id,)
- )
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.runInteraction(
"add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
@@ -168,39 +170,69 @@ class UserDirectoryStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
- if new_entry:
+ if self.database_engine.can_native_upsert:
sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('english', ?), 'A')
|| setweight(to_tsvector('english', ?), 'D')
|| setweight(to_tsvector('english', COALESCE(?, '')), 'B')
- )
+ ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
txn.execute(
sql,
(
- user_id, get_localpart_from_id(user_id),
- get_domain_from_id(user_id), display_name,
- )
+ user_id,
+ get_localpart_from_id(user_id),
+ get_domain_from_id(user_id),
+ display_name,
+ ),
)
else:
- sql = """
- UPDATE user_directory_search
- SET vector = setweight(to_tsvector('english', ?), 'A')
- || setweight(to_tsvector('english', ?), 'D')
- || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
- WHERE user_id = ?
- """
- txn.execute(
- sql,
- (
- get_localpart_from_id(user_id), get_domain_from_id(user_id),
- display_name, user_id,
+ # TODO: Remove this code after we've bumped the minimum version
+ # of postgres to always support upserts, so we can get rid of
+ # `new_entry` usage
+ if new_entry is True:
+ sql = """
+ INSERT INTO user_directory_search(user_id, vector)
+ VALUES (?,
+ setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ )
+ """
+ txn.execute(
+ sql,
+ (
+ user_id,
+ get_localpart_from_id(user_id),
+ get_domain_from_id(user_id),
+ display_name,
+ ),
+ )
+ elif new_entry is False:
+ sql = """
+ UPDATE user_directory_search
+ SET vector = setweight(to_tsvector('english', ?), 'A')
+ || setweight(to_tsvector('english', ?), 'D')
+ || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (
+ get_localpart_from_id(user_id),
+ get_domain_from_id(user_id),
+ display_name,
+ user_id,
+ ),
+ )
+ else:
+ raise RuntimeError(
+ "upsert returned None when 'can_native_upsert' is False"
)
- )
elif isinstance(self.database_engine, Sqlite3Engine):
- value = "%s %s" % (user_id, display_name,) if display_name else user_id
+ value = "%s %s" % (user_id, display_name) if display_name else user_id
self._simple_upsert_txn(
txn,
table="user_directory_search",
@@ -231,29 +263,18 @@ class UserDirectoryStore(SQLBaseStore):
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
self._simple_delete_txn(
- txn,
- table="user_directory",
- keyvalues={"user_id": user_id},
+ txn, table="user_directory", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
- txn,
- table="user_directory_search",
- keyvalues={"user_id": user_id},
+ txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
self._simple_delete_txn(
- txn,
- table="users_in_public_rooms",
- keyvalues={"user_id": user_id},
- )
- txn.call_after(
- self.get_user_in_directory.invalidate, (user_id,)
+ txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
- txn.call_after(
- self.get_user_in_public_room.invalidate, (user_id,)
- )
- return self.runInteraction(
- "remove_from_user_dir", _remove_from_user_dir_txn,
- )
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+ txn.call_after(self.get_user_in_public_room.invalidate, (user_id,))
+
+ return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def remove_from_user_in_public_room(self, user_id):
@@ -338,6 +359,7 @@ class UserDirectoryStore(SQLBaseStore):
share_private (bool): Is the room private
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
+
def _add_users_who_share_room_txn(txn):
self._simple_insert_many_txn(
txn,
@@ -354,13 +376,12 @@ class UserDirectoryStore(SQLBaseStore):
)
for user_id, other_user_id in user_id_tuples:
txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate,
- (user_id,),
+ self.get_users_who_share_room_from_dir.invalidate, (user_id,)
)
txn.call_after(
- self.get_if_users_share_a_room.invalidate,
- (user_id, other_user_id),
+ self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
)
+
return self.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
@@ -374,6 +395,7 @@ class UserDirectoryStore(SQLBaseStore):
share_private (bool): Is the room private
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
+
def _update_users_who_share_room_txn(txn):
sql = """
UPDATE users_who_share_rooms
@@ -381,21 +403,16 @@ class UserDirectoryStore(SQLBaseStore):
WHERE user_id = ? AND other_user_id = ?
"""
txn.executemany(
- sql,
- (
- (room_id, share_private, uid, oid)
- for uid, oid in user_id_sets
- )
+ sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets)
)
for user_id, other_user_id in user_id_sets:
txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate,
- (user_id,),
+ self.get_users_who_share_room_from_dir.invalidate, (user_id,)
)
txn.call_after(
- self.get_if_users_share_a_room.invalidate,
- (user_id, other_user_id),
+ self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
)
+
return self.runInteraction(
"update_users_who_share_room", _update_users_who_share_room_txn
)
@@ -409,22 +426,18 @@ class UserDirectoryStore(SQLBaseStore):
share_private (bool): Is the room private
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
"""
+
def _remove_user_who_share_room_txn(txn):
self._simple_delete_txn(
txn,
table="users_who_share_rooms",
- keyvalues={
- "user_id": user_id,
- "other_user_id": other_user_id,
- },
+ keyvalues={"user_id": user_id, "other_user_id": other_user_id},
)
txn.call_after(
- self.get_users_who_share_room_from_dir.invalidate,
- (user_id,),
+ self.get_users_who_share_room_from_dir.invalidate, (user_id,)
)
txn.call_after(
- self.get_if_users_share_a_room.invalidate,
- (user_id, other_user_id),
+ self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
)
return self.runInteraction(
@@ -445,10 +458,7 @@ class UserDirectoryStore(SQLBaseStore):
"""
return self._simple_select_one_onecol(
table="users_who_share_rooms",
- keyvalues={
- "user_id": user_id,
- "other_user_id": other_user_id,
- },
+ keyvalues={"user_id": user_id, "other_user_id": other_user_id},
retcol="share_private",
allow_none=True,
desc="get_if_users_share_a_room",
@@ -466,17 +476,12 @@ class UserDirectoryStore(SQLBaseStore):
"""
rows = yield self._simple_select_list(
table="users_who_share_rooms",
- keyvalues={
- "user_id": user_id,
- },
- retcols=("other_user_id", "share_private",),
+ keyvalues={"user_id": user_id},
+ retcols=("other_user_id", "share_private"),
desc="get_users_who_share_room_with_user",
)
- defer.returnValue({
- row["other_user_id"]: row["share_private"]
- for row in rows
- })
+ defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows})
def get_users_in_share_dir_with_room_id(self, user_id, room_id):
"""Get all user tuples that are in the users_who_share_rooms due to the
@@ -523,6 +528,7 @@ class UserDirectoryStore(SQLBaseStore):
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
+
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")
txn.execute("DELETE FROM user_directory_search")
@@ -532,6 +538,7 @@ class UserDirectoryStore(SQLBaseStore):
txn.call_after(self.get_user_in_public_room.invalidate_all)
txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+
return self.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@@ -541,7 +548,7 @@ class UserDirectoryStore(SQLBaseStore):
return self._simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
- retcols=("room_id", "display_name", "avatar_url",),
+ retcols=("room_id", "display_name", "avatar_url"),
allow_none=True,
desc="get_user_in_directory",
)
@@ -574,7 +581,9 @@ class UserDirectoryStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id):
prev_stream_id = int(prev_stream_id)
- if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+ if not self._curr_state_delta_stream_cache.has_any_entity_changed(
+ prev_stream_id
+ ):
return []
def get_current_state_deltas_txn(txn):
@@ -608,7 +617,7 @@ class UserDirectoryStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
- txn.execute(sql, (prev_stream_id, max_stream_id,))
+ txn.execute(sql, (prev_stream_id, max_stream_id))
return self.cursor_to_dict(txn)
return self.runInteraction(
@@ -698,8 +707,11 @@ class UserDirectoryStore(SQLBaseStore):
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (join_clause, where_clause)
- args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
+ """ % (
+ join_clause,
+ where_clause,
+ )
+ args = join_args + (full_query, exact_query, prefix_query, limit + 1)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
@@ -716,7 +728,10 @@ class UserDirectoryStore(SQLBaseStore):
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (join_clause, where_clause)
+ """ % (
+ join_clause,
+ where_clause,
+ )
args = join_args + (search_query, limit + 1)
else:
# This should be unreachable.
@@ -728,10 +743,7 @@ class UserDirectoryStore(SQLBaseStore):
limited = len(results) > limit
- defer.returnValue({
- "limited": limited,
- "results": results,
- })
+ defer.returnValue({"limited": limited, "results": results})
def _parse_query_sqlite(search_term):
@@ -746,7 +758,7 @@ def _parse_query_sqlite(search_term):
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
- return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
+ return " & ".join("(%s* OR %s)" % (result, result) for result in results)
def _parse_query_postgres(search_term):
@@ -759,7 +771,7 @@ def _parse_query_postgres(search_term):
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
- both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+ both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
exact = " & ".join("%s" % (result,) for result in results)
prefix = " & ".join("%s:*" % (result,) for result in results)
|