Diffstat (limited to 'synapse/storage')
30 files changed, 1309 insertions, 883 deletions
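The common thread through this changeset is the new make_in_list_sql_clause helper added to synapse/storage/_base.py, which replaces hand-rolled "column IN (?, ?, ...)" string building at each call site. A minimal sketch of the helper and its call pattern, assuming a transaction cursor txn that exposes a database_engine attribute (as the callers in this diff do):

    from typing import Iterable, Tuple

    def make_in_list_sql_clause(
        database_engine, column: str, iterable: Iterable
    ) -> Tuple[str, Iterable]:
        if database_engine.supports_using_any_list:
            # Postgres: "col = ANY(?)" with a single list argument, so queries
            # with different numbers of values look identical to query stats.
            return "%s = ANY(?)" % (column,), [list(iterable)]
        # SQLite: fall back to the traditional "col IN (?, ?, ...)" form.
        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)

    # Call pattern used throughout the diff below (values here are illustrative):
    # clause, args = make_in_list_sql_clause(txn.database_engine, "e.event_id", event_ids)
    # txn.execute("SELECT event_id FROM events AS e WHERE " + clause, args)

On SQLite the clause expands to "e.event_id IN (?, ?, ...)" with one bind parameter per value; on Postgres it expands to "e.event_id = ANY(?)" with the values passed as a single list parameter.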
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index abe16334ec..f5906fcd54 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -20,6 +20,7 @@ import random import sys import threading import time +from typing import Iterable, Tuple from six import PY2, iteritems, iterkeys, itervalues from six.moves import builtins, intern, range @@ -30,7 +31,7 @@ from prometheus_client import Histogram from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id @@ -550,8 +551,9 @@ class SQLBaseStore(object): return func(conn, *args, **kwargs) - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs) + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) return result @@ -1162,19 +1164,18 @@ class SQLBaseStore(object): if not iterable: return [] - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - - clauses = [] - values = [] - clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) - values.extend(iterable) + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] for key, value in iteritems(keyvalues): clauses.append("%s = ?" % (key,)) values.append(value) - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) txn.execute(sql, values) return cls.cursor_to_dict(txn) @@ -1323,10 +1324,8 @@ class SQLBaseStore(object): sql = "DELETE FROM %s" % table - clauses = [] - values = [] - clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) - values.extend(iterable) + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] for key, value in iteritems(keyvalues): clauses.append("%s = ?" % (key,)) @@ -1693,3 +1692,30 @@ def db_to_json(db_content): except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, Iterable]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if database_engine.supports_using_any_list: + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" 
for _ in iterable)), list(iterable) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index bb135166ce..067820a5da 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -33,16 +33,9 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 -class ClientIpStore(background_updates.BackgroundUpdateStore): +class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): def __init__(self, db_conn, hs): - - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR - ) - - super(ClientIpStore, self).__init__(db_conn, hs) - - self.user_ips_max_age = hs.config.user_ips_max_age + super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) self.register_background_index_update( "user_ips_device_index", @@ -92,19 +85,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "devices_last_seen", self._devices_last_seen_update ) - # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update = {} - - self._client_ip_looper = self._clock.looping_call( - self._update_client_ips_batch, 5 * 1000 - ) - self.hs.get_reactor().addSystemEventTrigger( - "before", "shutdown", self._update_client_ips_batch - ) - - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - @defer.inlineCallbacks def _remove_user_ip_nonunique(self, progress, batch_size): def f(conn): @@ -304,6 +284,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): return batch_size @defer.inlineCallbacks + def _devices_last_seen_update(self, progress, batch_size): + """Background update to insert last seen info into devices table + """ + + last_user_id = progress.get("last_user_id", "") + last_device_id = progress.get("last_device_id", "") + + def _devices_last_seen_update_txn(txn): + # This consists of two queries: + # + # 1. The sub-query searches for the next N devices and joins + # against user_ips to find the max last_seen associated with + # that device. + # 2. The outer query then joins again against user_ips on + # user/device/last_seen. This *should* hopefully only + # return one row, but if it does return more than one then + # we'll just end up updating the same device row multiple + # times, which is fine. + + if self.database_engine.supports_tuple_comparison: + where_clause = "(user_id, device_id) > (?, ?)" + where_args = [last_user_id, last_device_id] + else: + # We explicitly do a `user_id >= ? AND (...)` here to ensure + # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)` + # makes it hard for query optimiser to tell that it can use the + # index on user_id + where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)" + where_args = [last_user_id, last_user_id, last_device_id] + + sql = """ + SELECT + last_seen, ip, user_agent, user_id, device_id + FROM ( + SELECT + user_id, device_id, MAX(u.last_seen) AS last_seen + FROM devices + INNER JOIN user_ips AS u USING (user_id, device_id) + WHERE %(where_clause)s + GROUP BY user_id, device_id + ORDER BY user_id ASC, device_id ASC + LIMIT ? + ) c + INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) + """ % { + "where_clause": where_clause + } + txn.execute(sql, where_args + [batch_size]) + + rows = txn.fetchall() + if not rows: + return 0 + + sql = """ + UPDATE devices + SET last_seen = ?, ip = ?, user_agent = ? + WHERE user_id = ? AND device_id = ? 
+ """ + txn.execute_batch(sql, rows) + + _, _, _, user_id, device_id = rows[-1] + self._background_update_progress_txn( + txn, + "devices_last_seen", + {"last_user_id": user_id, "last_device_id": device_id}, + ) + + return len(rows) + + updated = yield self.runInteraction( + "_devices_last_seen_update", _devices_last_seen_update_txn + ) + + if not updated: + yield self._end_background_update("devices_last_seen") + + return updated + + +class ClientIpStore(ClientIpBackgroundUpdateStore): + def __init__(self, db_conn, hs): + + self.client_ip_last_seen = Cache( + name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + ) + + super(ClientIpStore, self).__init__(db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) + self._batch_row_update = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + self.hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", self._update_client_ips_batch + ) + + if self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @defer.inlineCallbacks def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): @@ -454,85 +538,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): for (access_token, ip), (user_agent, last_seen) in iteritems(results) ) - @defer.inlineCallbacks - def _devices_last_seen_update(self, progress, batch_size): - """Background update to insert last seen info into devices table - """ - - last_user_id = progress.get("last_user_id", "") - last_device_id = progress.get("last_device_id", "") - - def _devices_last_seen_update_txn(txn): - # This consists of two queries: - # - # 1. The sub-query searches for the next N devices and joins - # against user_ips to find the max last_seen associated with - # that device. - # 2. The outer query then joins again against user_ips on - # user/device/last_seen. This *should* hopefully only - # return one row, but if it does return more than one then - # we'll just end up updating the same device row multiple - # times, which is fine. - - if self.database_engine.supports_tuple_comparison: - where_clause = "(user_id, device_id) > (?, ?)" - where_args = [last_user_id, last_device_id] - else: - # We explicitly do a `user_id >= ? AND (...)` here to ensure - # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)` - # makes it hard for query optimiser to tell that it can use the - # index on user_id - where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)" - where_args = [last_user_id, last_user_id, last_device_id] - - sql = """ - SELECT - last_seen, ip, user_agent, user_id, device_id - FROM ( - SELECT - user_id, device_id, MAX(u.last_seen) AS last_seen - FROM devices - INNER JOIN user_ips AS u USING (user_id, device_id) - WHERE %(where_clause)s - GROUP BY user_id, device_id - ORDER BY user_id ASC, device_id ASC - LIMIT ? - ) c - INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) - """ % { - "where_clause": where_clause - } - txn.execute(sql, where_args + [batch_size]) - - rows = txn.fetchall() - if not rows: - return 0 - - sql = """ - UPDATE devices - SET last_seen = ?, ip = ?, user_agent = ? - WHERE user_id = ? AND device_id = ? 
- """ - txn.execute_batch(sql, rows) - - _, _, _, user_id, device_id = rows[-1] - self._background_update_progress_txn( - txn, - "devices_last_seen", - {"last_user_id": user_id, "last_device_id": device_id}, - ) - - return len(rows) - - updated = yield self.runInteraction( - "_devices_last_seen_update", _devices_last_seen_update_txn - ) - - if not updated: - yield self._end_background_update("devices_last_seen") - - return updated - @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self): """Removes entries in user IPs older than the configured period. diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 6b7458304e..f04aad0743 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -20,7 +20,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.expiringcache import ExpiringCache @@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) -class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): +class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) self.register_background_index_update( "device_inbox_stream_index", @@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") + txn.close() + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + return 1 + + +class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, db_conn, hs): + super(DeviceInboxStore, self).__init__(db_conn, hs) + # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. self._last_device_delete_cache = ExpiringCache( @@ -358,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): else: if not devices: continue - sql = ( - "SELECT device_id FROM devices" - " WHERE user_id = ? AND device_id IN (" - + ",".join("?" * len(devices)) - + ")" + + clause, args = make_in_list_sql_clause( + txn.database_engine, "device_id", devices ) + sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause + # TODO: Maybe this needs to be done in batches if there are # too many local devices for a given user. 
- txn.execute(sql, [user_id] + devices) + txn.execute(sql, [user_id] + list(args)) for row in txn: # Only insert into the local inbox if the device exists on # this server @@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): return self.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) - - @defer.inlineCallbacks - def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") - txn.close() - - yield self.runWithConnection(reindex_txn) - - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) - - return 1 diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 79a58df591..ac5239e50a 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -28,7 +28,12 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import Cache, SQLBaseStore, db_to_json +from synapse.storage._base import ( + Cache, + SQLBaseStore, + db_to_json, + make_in_list_sql_clause, +) from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList @@ -448,11 +453,14 @@ class DeviceWorkerStore(SQLBaseStore): sql = """ SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? - AND user_id IN (%s) + AND """ for chunk in batch_iter(to_check, 100): - txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk) + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql + clause, (from_key,) + tuple(args)) changes.update(user_id for user_id, in txn) return changes @@ -512,17 +520,9 @@ class DeviceWorkerStore(SQLBaseStore): return results -class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): +class DeviceBackgroundUpdateStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) - - # Map of (user_id, device_id) -> bool. If there is an entry that implies - # the device exists. - self.device_id_exists_cache = Cache( - name="device_id_exists", keylen=2, max_entries=10000 - ) - - self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) + super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) self.register_background_index_update( "device_lists_stream_idx", @@ -556,6 +556,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) @defer.inlineCallbacks + def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): + def f(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") + txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") + txn.close() + + yield self.runWithConnection(f) + yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + return 1 + + +class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(DeviceStore, self).__init__(db_conn, hs) + + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. 
+ self.device_id_exists_cache = Cache( + name="device_id_exists", keylen=2, max_entries=10000 + ) + + self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) + + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): """Ensure the given device is known; add it to the store if not @@ -910,15 +935,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): "_prune_old_outbound_device_pokes", _prune_txn, ) - - @defer.inlineCallbacks - def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): - def f(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") - txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") - txn.close() - - yield self.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) - return 1 diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 33e3a84933..872bc75490 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -40,7 +40,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to - dict containing "key_json", "device_display_name". + key data. The key data will be a dict in the same format as the + DeviceKeys type returned by POST /_matrix/client/r0/keys/query. """ set_tag("query_list", query_list) if not query_list: @@ -54,11 +55,20 @@ class EndToEndKeyWorkerStore(SQLBaseStore): include_deleted_devices, ) + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section + rv = {} for user_id, device_keys in iteritems(results): + rv[user_id] = {} for device_id, device_info in iteritems(device_keys): - device_info["keys"] = db_to_json(device_info.pop("key_json")) - - return results + r = db_to_json(device_info.pop("key_json")) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + rv[user_id][device_id] = r + + return rv @trace def _get_e2e_device_keys_txn( diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 601617b21e..b7c4eda338 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -22,6 +22,13 @@ class PostgresEngine(object): def __init__(self, database_module, database_config): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) + + # Disables passing `bytes` to txn.execute, c.f. #6186. If you do + # actually want to use bytes than wrap it in `bytearray`. 
+ def _disable_bytes_adapter(_): + raise Exception("Passing bytes to DB is disabled.") + + self.module.extensions.register_adapter(bytes, _disable_bytes_adapter) self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet @@ -79,6 +86,12 @@ class PostgresEngine(object): """ return True + @property + def supports_using_any_list(self): + """Do we support using `a = ANY(?)` and passing a list + """ + return True + def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index ac92109366..ddad17dc5a 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -46,6 +46,12 @@ class Sqlite3Engine(object): """ return self.module.sqlite_version_info >= (3, 15, 0) + @property + def supports_using_any_list(self): + """Do we support using `a = ANY(?)` and passing a list + """ + return False + def check_database(self, txn): pass diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index f5e8c39262..47cc10d32a 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -25,7 +25,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.util.caches.descriptors import cached @@ -68,7 +68,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: results = set() - base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" + base_sql = "SELECT auth_id FROM event_auth WHERE " front = set(event_ids) while front: @@ -76,7 +76,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas front_list = list(front) chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] for chunk in chunks: - txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk) + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", chunk + ) + txn.execute(base_sql + clause, list(args)) new_front.update([r[0] for r in txn]) new_front -= results diff --git a/synapse/storage/events.py b/synapse/storage/events.py index bb6ff0595a..ee49ef235d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -39,6 +39,7 @@ from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore +from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore @@ -641,14 +642,16 @@ class EventsStore( LEFT JOIN rejections USING (event_id) LEFT JOIN event_json USING (event_id) WHERE - prev_event_id IN (%s) - AND NOT events.outlier + NOT events.outlier AND rejections.event_id IS NULL - """ % ( - ",".join("?" 
for _ in batch), + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", batch ) - txn.execute(sql, batch) + txn.execute(sql + clause, args) results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): @@ -695,13 +698,15 @@ class EventsStore( LEFT JOIN rejections USING (event_id) LEFT JOIN event_json USING (event_id) WHERE - event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" for _ in to_recursively_check), + NOT events.outlier + AND + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", to_recursively_check ) - txn.execute(sql, to_recursively_check) + txn.execute(sql + clause, args) to_recursively_check = [] for event_id, prev_event_id, metadata, rejected in txn: @@ -1543,10 +1548,14 @@ class EventsStore( " FROM events as e" " LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(ev_map)),) + " WHERE " + ) + + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", list(ev_map) + ) - txn.execute(sql, list(ev_map)) + txn.execute(sql + clause, args) rows = self.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] @@ -2249,11 +2258,12 @@ class EventsStore( sql = """ SELECT DISTINCT state_group FROM event_to_state_groups LEFT JOIN events_to_purge AS ep USING (event_id) - WHERE state_group IN (%s) AND ep.event_id IS NULL - """ % ( - ",".join("?" for _ in current_search), + WHERE ep.event_id IS NULL AND + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "state_group", current_search ) - txn.execute(sql, list(current_search)) + txn.execute(sql + clause, list(args)) referenced = set(sg for sg, in txn) referenced_groups |= referenced diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index e77a7e28af..31ea6f917f 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -21,6 +21,7 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) @@ -325,12 +326,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): INNER JOIN event_json USING (event_id) LEFT JOIN rejections USING (event_id) WHERE - prev_event_id IN (%s) - AND NOT events.outlier - """ % ( - ",".join("?" 
for _ in to_check), + NOT events.outlier + AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "prev_event_id", to_check ) - txn.execute(sql, to_check) + txn.execute(sql + clause, list(args)) for prev_event_id, event_id, metadata, rejected in txn: if event_id in graph: diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 57ce0304e9..4c4b76bd93 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -31,12 +31,11 @@ from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.types import get_domain_from_id from synapse.util import batch_iter from synapse.util.metrics import Measure -from ._base import SQLBaseStore - logger = logging.getLogger(__name__) @@ -623,10 +622,14 @@ class EventsWorkerStore(SQLBaseStore): " rej.reason " " FROM event_json as e" " LEFT JOIN rejections as rej USING (event_id)" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) + " WHERE " + ) - txn.execute(sql, evs) + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", evs + ) + + txn.execute(sql + clause, args) for row in txn: event_id = row[0] @@ -640,11 +643,11 @@ class EventsWorkerStore(SQLBaseStore): } # check for redactions - redactions_sql = ( - "SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)" - ) % (",".join(["?"] * len(evs)),) + redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " + + clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) - txn.execute(redactions_sql, evs) + txn.execute(redactions_sql + clause, args) for (redacter, redacted) in txn: d = event_dict.get(redacted) @@ -753,10 +756,11 @@ class EventsWorkerStore(SQLBaseStore): results = set() def have_seen_events_txn(txn, chunk): - sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % ( - ",".join("?" * len(chunk)), + sql = "SELECT event_id FROM events as e WHERE " + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", chunk ) - txn.execute(sql, chunk) + txn.execute(sql + clause, args) for (event_id,) in txn: results.add(event_id) diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 23b48f6cea..7c2a7da836 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -51,7 +51,7 @@ class FilteringStore(SQLBaseStore): "SELECT filter_id FROM user_filters " "WHERE user_id = ? AND filter_json = ?" 
) - txn.execute(sql, (user_localpart, def_json)) + txn.execute(sql, (user_localpart, bytearray(def_json))) filter_id_response = txn.fetchone() if filter_id_response is not None: return filter_id_response[0] @@ -68,7 +68,7 @@ class FilteringStore(SQLBaseStore): "INSERT INTO user_filters (user_id, filter_id, filter_json)" "VALUES(?, ?, ?)" ) - txn.execute(sql, (user_localpart, filter_id, def_json)) + txn.execute(sql, (user_localpart, filter_id, bytearray(def_json))) return filter_id diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 6b1238ce4a..84b5f3ad5e 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -15,11 +15,9 @@ from synapse.storage.background_updates import BackgroundUpdateStore -class MediaRepositoryStore(BackgroundUpdateStore): - """Persistence for attachments and avatars""" - +class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) + super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) self.register_background_index_update( update_name="local_media_repository_url_idx", @@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore): where_clause="url_cache IS NOT NULL", ) + +class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): + """Persistence for attachments and avatars""" + + def __init__(self, db_conn, hs): + super(MediaRepositoryStore, self).__init__(db_conn, hs) + def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 752e9788a2..3803604be7 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -32,7 +32,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): super(MonthlyActiveUsersStore, self).__init__(None, hs) self._clock = hs.get_clock() self.hs = hs - self.reserved_users = () # Do not add more reserved users than the total allowable number self._new_transaction( dbconn, @@ -51,7 +50,6 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn (cursor): threepids (list[dict]): List of threepid dicts to reserve """ - reserved_user_list = [] for tp in threepids: user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) @@ -60,10 +58,8 @@ class MonthlyActiveUsersStore(SQLBaseStore): is_support = self.is_support_user_txn(txn, user_id) if not is_support: self.upsert_monthly_active_user_txn(txn, user_id) - reserved_user_list.append(user_id) else: logger.warning("mau limit reserved threepid %s not found in db" % tp) - self.reserved_users = tuple(reserved_user_list) @defer.inlineCallbacks def reap_monthly_active_users(self): @@ -74,8 +70,11 @@ class MonthlyActiveUsersStore(SQLBaseStore): Deferred[] """ - def _reap_users(txn): - # Purge stale users + def _reap_users(txn, reserved_users): + """ + Args: + reserved_users (tuple): reserved users to preserve + """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) query_args = [thirty_days_ago] @@ -83,20 +82,19 @@ class MonthlyActiveUsersStore(SQLBaseStore): # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # when len(reserved_users) == 0. Works fine on sqlite. - if len(self.reserved_users) > 0: + if len(reserved_users) > 0: # questionmarks is a hack to overcome sqlite not supporting # tuples in 'WHERE IN %s' - questionmarks = "?" 
* len(self.reserved_users) + question_marks = ",".join("?" * len(reserved_users)) - query_args.extend(self.reserved_users) - sql = base_sql + """ AND user_id NOT IN ({})""".format( - ",".join(questionmarks) - ) + query_args.extend(reserved_users) + sql = base_sql + " AND user_id NOT IN ({})".format(question_marks) else: sql = base_sql txn.execute(sql, query_args) + max_mau_value = self.hs.config.max_mau_value if self.hs.config.limit_usage_by_mau: # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. @@ -106,31 +104,52 @@ class MonthlyActiveUsersStore(SQLBaseStore): # While Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can # support - safe_guard = self.hs.config.max_mau_value - len(self.reserved_users) - # Must be greater than zero for postgres - safe_guard = safe_guard if safe_guard > 0 else 0 - query_args = [safe_guard] - - base_sql = """ - DELETE FROM monthly_active_users - WHERE user_id NOT IN ( - SELECT user_id FROM monthly_active_users - ORDER BY timestamp DESC - LIMIT ? + if len(reserved_users) == 0: + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + ORDER BY timestamp DESC + LIMIT ? ) - """ + """ + txn.execute(sql, (max_mau_value,)) # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # when len(reserved_users) == 0. Works fine on sqlite. - if len(self.reserved_users) > 0: - query_args.extend(self.reserved_users) - sql = base_sql + """ AND user_id NOT IN ({})""".format( - ",".join(questionmarks) - ) else: - sql = base_sql - txn.execute(sql, query_args) + # Must be >= 0 for postgres + num_of_non_reserved_users_to_remove = max( + max_mau_value - len(reserved_users), 0 + ) + + # It is important to filter reserved users twice to guard + # against the case where the reserved user is present in the + # SELECT, meaning that a legitmate mau is deleted. + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + WHERE user_id NOT IN ({}) + ORDER BY timestamp DESC + LIMIT ? + ) + AND user_id NOT IN ({}) + """.format( + question_marks, question_marks + ) + + query_args = [ + *reserved_users, + num_of_non_reserved_users_to_remove, + *reserved_users, + ] - yield self.runInteraction("reap_monthly_active_users", _reap_users) + txn.execute(sql, query_args) + + reserved_users = yield self.get_registered_reserved_users() + yield self.runInteraction( + "reap_monthly_active_users", _reap_users, reserved_users + ) # It seems poor to invalidate the whole cache, Postgres supports # 'Returning' which would allow me to invalidate only the # specific users, but sqlite has no way to do this and instead @@ -159,21 +178,25 @@ class MonthlyActiveUsersStore(SQLBaseStore): return self.runInteraction("count_users", _count_users) @defer.inlineCallbacks - def get_registered_reserved_users_count(self): - """Of the reserved threepids defined in config, how many are associated + def get_registered_reserved_users(self): + """Of the reserved threepids defined in config, which are associated with registered users? 
Returns: - Defered[int]: Number of real reserved users + Defered[list]: Real reserved users """ - count = 0 - for tp in self.hs.config.mau_limits_reserved_threepids: + users = [] + + for tp in self.hs.config.mau_limits_reserved_threepids[ + : self.hs.config.max_mau_value + ]: user_id = yield self.hs.get_datastore().get_user_id_by_threepid( tp["medium"], tp["address"] ) if user_id: - count = count + 1 - return count + users.append(user_id) + + return users @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 5db6f2d84a..3a641f538b 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -18,11 +18,10 @@ from collections import namedtuple from twisted.internet import defer from synapse.api.constants import PresenceState +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedList -from ._base import SQLBaseStore - class UserPresenceState( namedtuple( @@ -119,14 +118,13 @@ class PresenceStore(SQLBaseStore): ) # Delete old rows to stop database from getting really big - sql = ( - "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)" - ) + sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " for states in batch_iter(presence_states, 50): - args = [stream_id] - args.extend(s.user_id for s in states) - txn.execute(sql % (",".join("?" for _ in states),), args) + clause, args = make_in_list_sql_clause( + self.database_engine, "user_id", [s.user_id for s in states] + ) + txn.execute(sql + clause, [stream_id] + list(args)) def get_all_presence_updates(self, last_id, current_id): if last_id == current_id: diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index a6517c4cf3..c4e24edff2 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -183,8 +183,8 @@ class PushRulesWorkerStore( return results @defer.inlineCallbacks - def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule): - """Move a single push rule from one room to another for a specific user. + def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + """Copy a single push rule from one room to another for a specific user. Args: new_room_id (str): ID of the new room. @@ -209,14 +209,11 @@ class PushRulesWorkerStore( actions=rule["actions"], ) - # Delete push rule for the old room - yield self.delete_push_rule(user_id, rule["rule_id"]) - @defer.inlineCallbacks - def move_push_rules_from_room_to_room_for_user( + def copy_push_rules_from_room_to_room_for_user( self, old_room_id, new_room_id, user_id ): - """Move all of the push rules from one room to another for a specific + """Copy all of the push rules from one room to another for a specific user. 
Args: @@ -227,15 +224,14 @@ class PushRulesWorkerStore( # Retrieve push rules for this user user_push_rules = yield self.get_push_rules_for_user(user_id) - # Get rules relating to the old room, move them to the new room, then - # delete them from the old room + # Get rules relating to the old room and copy them to the new room for rule in user_push_rules: conditions = rule.get("conditions", []) if any( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions ): - self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) + yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 3e0e834a62..b12e80440a 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -241,7 +241,7 @@ class PusherStore(PusherWorkerStore): "device_display_name": device_display_name, "ts": pushkey_ts, "lang": lang, - "data": encode_canonical_json(data), + "data": bytearray(encode_canonical_json(data)), "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 290ddb30e8..0c24430f28 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -21,12 +21,11 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import SQLBaseStore -from .util.id_generators import StreamIdGenerator - logger = logging.getLogger(__name__) @@ -217,24 +216,26 @@ class ReceiptsWorkerStore(SQLBaseStore): def f(txn): if from_key: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id IN (%s) AND stream_id > ? AND stream_id <= ?" - ) % (",".join(["?"] * len(room_ids))) - args = list(room_ids) - args.extend([from_key, to_key]) + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? AND + """ + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) - txn.execute(sql, args) + txn.execute(sql + clause, [from_key, to_key] + list(args)) else: - sql = ( - "SELECT * FROM receipts_linearized WHERE" - " room_id IN (%s) AND stream_id <= ?" - ) % (",".join(["?"] * len(room_ids))) + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? AND + """ - args = list(room_ids) - args.append(to_key) + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) - txn.execute(sql, args) + txn.execute(sql + clause, [to_key] + list(args)) return self.cursor_to_dict(txn) @@ -433,13 +434,19 @@ class ReceiptsStore(ReceiptsWorkerStore): # we need to points in graph -> linearized form. # TODO: Make this better. def graph_to_linear(txn): - query = ( - "SELECT event_id WHERE room_id = ? AND stream_ordering IN (" - " SELECT max(stream_ordering) WHERE event_id IN (%s)" - ")" - ) % (",".join(["?"] * len(event_ids))) + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? 
AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) - txn.execute(query, [room_id] + event_ids) + txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() if rows: return rows[0][0] diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 241a7be51e..6c5b29288a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -493,7 +493,9 @@ class RegistrationWorkerStore(SQLBaseStore): """ def _find_next_generated_user_id(txn): - txn.execute("SELECT name FROM users") + # We bound between '@1' and '@a' to avoid pulling the entire table + # out. + txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'") regex = re.compile(r"^@(\d+):") @@ -785,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore): ) -class RegistrationStore( +class RegistrationBackgroundUpdateStore( RegistrationWorkerStore, background_updates.BackgroundUpdateStore ): def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() + self.config = hs.config self.register_background_index_update( "access_tokens_device_index", @@ -807,8 +810,6 @@ class RegistrationStore( columns=["creation_ts"], ) - self._account_validity = hs.config.account_validity - # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. @@ -822,17 +823,6 @@ class RegistrationStore( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - @defer.inlineCallbacks def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 @@ -895,6 +885,54 @@ class RegistrationStore( return nb_processed @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? 
+ FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn + ) + + yield self._end_background_update("user_threepids_grandfather") + + return 1 + + +class RegistrationStore(RegistrationBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RegistrationStore, self).__init__(db_conn, hs) + + self._account_validity = hs.config.account_validity + + # Create a background job for culling expired 3PID validity tokens + def start_cull(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "cull_expired_threepid_validation_tokens", + self.cull_expired_threepid_validation_tokens, + ) + + hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) + + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): """Adds an access token for the given user. @@ -1242,36 +1280,6 @@ class RegistrationStore( desc="get_users_pending_deactivation", ) - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): - """We now track which identity servers a user binds their 3PID to, so - we need to handle the case of existing bindings where we didn't track - this. - - We do this by grandfathering in existing user threepids assuming that - they used one of the server configured trusted identity servers. - """ - id_servers = set(self.config.trusted_third_party_id_servers) - - def _bg_user_threepids_grandfather_txn(txn): - sql = """ - INSERT INTO user_threepid_id_server - (user_id, medium, address, id_server) - SELECT user_id, medium, address, ? - FROM user_threepids - """ - - txn.executemany(sql, [(id_server,) for id_server in id_servers]) - - if id_servers: - yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn - ) - - yield self._end_background_update("user_threepids_grandfather") - - return 1 - def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token @@ -1463,17 +1471,6 @@ class RegistrationStore( self.clock.time_msec(), ) - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - @defer.inlineCallbacks def set_user_deactivated_status(self, user_id, deactivated): """Set the `deactivated` property for the provided user to the provided value. @@ -1489,3 +1486,14 @@ class RegistrationStore( user_id, deactivated, ) + + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + self._simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 08e13f3a3b..43cc56fa6f 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,7 @@ import collections import logging import re +from typing import Optional, Tuple from canonicaljson import json @@ -24,6 +26,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.search import SearchStore +from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -63,103 +66,196 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - @cached(num_args=2, max_entries=100) - def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): - """Get pulbic rooms for a particular list, or across all lists. + def count_public_rooms(self, network_tuple, ignore_non_federatable): + """Counts the number of public rooms as tracked in the room_stats_current + and room_stats_state table. Args: - stream_id (int) - network_tuple (ThirdPartyInstanceID): The list to use (None, None) - means the main list, None means all lsits. + network_tuple (ThirdPartyInstanceID|None) + ignore_non_federatable (bool): If true filters out non-federatable rooms """ - return self.runInteraction( - "get_public_room_ids_at_stream_id", - self.get_public_room_ids_at_stream_id_txn, - stream_id, - network_tuple=network_tuple, - ) - - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple): - return { - rm - for rm, vis in self.get_published_at_stream_id_txn( - txn, stream_id, network_tuple=network_tuple - ).items() - if vis - } - def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): - if network_tuple: - # We want to get from a particular list. No aggregation required. + def _count_public_rooms_txn(txn): + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list + """ sql = """ - SELECT room_id, visibility FROM public_room_list_stream - INNER JOIN ( - SELECT room_id, max(stream_id) AS stream_id - FROM public_room_list_stream - WHERE stream_id <= ? %s - GROUP BY room_id - ) grouped USING (room_id, stream_id) + SELECT + COALESCE(COUNT(*), 0) + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + """ % { + "published_sql": published_sql + } + + txn.execute(sql, query_args) + return txn.fetchone()[0] + + return self.runInteraction("count_public_rooms", _count_public_rooms_txn) + + @defer.inlineCallbacks + def get_largest_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + search_filter: Optional[dict], + limit: Optional[int], + bounds: Optional[Tuple[int, str]], + forwards: bool, + ignore_non_federatable: bool = False, + ): + """Gets the largest public rooms (where largest is in terms of joined + members, as tracked in the statistics table). 
+ + Args: + network_tuple + search_filter + limit: Maxmimum number of rows to return, unlimited otherwise. + bounds: An uppoer or lower bound to apply to result set if given, + consists of a joined member count and room_id (these are + excluded from result set). + forwards: true iff going forwards, going backwards otherwise + ignore_non_federatable: If true filters out non-federatable rooms. + + Returns: + Rooms in order: biggest number of joined users first. + We then arbitrarily use the room_id as a tie breaker. + + """ + + where_clauses = [] + query_args = [] + + if network_tuple: + if network_tuple.appservice_id: + published_sql = """ + SELECT room_id from appservice_room_list + WHERE appservice_id = ? AND network_id = ? + """ + query_args.append(network_tuple.appservice_id) + query_args.append(network_tuple.network_id) + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + """ + else: + published_sql = """ + SELECT room_id FROM rooms WHERE is_public + UNION SELECT room_id from appservice_room_list """ - if network_tuple.appservice_id is not None: - txn.execute( - sql % ("AND appservice_id = ? AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id), + # Work out the bounds if we're given them, these bounds look slightly + # odd, but are designed to help query planner use indices by pulling + # out a common bound. + if bounds: + last_joined_members, last_room_id = bounds + if forwards: + where_clauses.append( + """ + joined_members <= ? AND ( + joined_members < ? OR room_id < ? + ) + """ ) else: - txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,)) - return dict(txn) - else: - # We want to get from all lists, so we need to aggregate the results + where_clauses.append( + """ + joined_members >= ? AND ( + joined_members > ? OR room_id > ? + ) + """ + ) - logger.info("Executing full list") + query_args += [last_joined_members, last_joined_members, last_room_id] - sql = """ - SELECT room_id, visibility - FROM public_room_list_stream - INNER JOIN ( - SELECT - room_id, max(stream_id) AS stream_id, appservice_id, - network_id - FROM public_room_list_stream - WHERE stream_id <= ? - GROUP BY room_id, appservice_id, network_id - ) grouped USING (room_id, stream_id) - """ + if ignore_non_federatable: + where_clauses.append("is_federatable") - txn.execute(sql, (stream_id,)) + if search_filter and search_filter.get("generic_search_term", None): + search_term = "%" + search_filter["generic_search_term"] + "%" - results = {} - # A room is visible if its visible on any list. - for room_id, visibility in txn: - results[room_id] = bool(visibility) or results.get(room_id, False) + where_clauses.append( + """ + ( + name LIKE ? + OR topic LIKE ? + OR canonical_alias LIKE ? 
+ ) + """ + ) + query_args += [search_term, search_term, search_term] + + where_clause = "" + if where_clauses: + where_clause = " AND " + " AND ".join(where_clauses) + + sql = """ + SELECT + room_id, name, topic, canonical_alias, joined_members, + avatar, history_visibility, joined_members, guest_access + FROM ( + %(published_sql)s + ) published + INNER JOIN room_stats_state USING (room_id) + INNER JOIN room_stats_current USING (room_id) + WHERE + ( + join_rules = 'public' OR history_visibility = 'world_readable' + ) + AND joined_members > 0 + %(where_clause)s + ORDER BY joined_members %(dir)s, room_id %(dir)s + """ % { + "published_sql": published_sql, + "where_clause": where_clause, + "dir": "DESC" if forwards else "ASC", + } - return results + if limit is not None: + query_args.append(limit) - def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple): - def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn( - txn, prev_stream_id, network_tuple - ) + sql += """ + LIMIT ? + """ - now_rooms_dict = self.get_published_at_stream_id_txn( - txn, new_stream_id, network_tuple - ) + def _get_largest_public_rooms_txn(txn): + txn.execute(sql, query_args) - now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis) - now_rooms_not_visible = set( - rm for rm, vis in now_rooms_dict.items() if not vis - ) + results = self.cursor_to_dict(txn) - newly_visible = now_rooms_visible - then_rooms - newly_unpublished = now_rooms_not_visible & then_rooms + if not forwards: + results.reverse() - return newly_visible, newly_unpublished + return results - return self.runInteraction( - "get_public_room_changes", get_public_room_changes_txn + ret_val = yield self.runInteraction( + "get_largest_public_rooms", _get_largest_public_rooms_txn ) + defer.returnValue(ret_val) @cached(max_entries=10000) def is_room_blocked(self, room_id): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 4df8ebdacd..ff63487823 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -26,13 +26,15 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction +from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.metrics import Measure from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -370,6 +372,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): results = [] if membership_list: if self._current_state_events_membership_up_to_date: + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) sql = """ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering FROM current_state_events AS c @@ -377,11 +382,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): WHERE c.type = 'm.room.member' AND state_key = ? 
- AND c.membership IN (%s) + AND %s """ % ( - ",".join("?" * len(membership_list)) + clause, ) else: + clause, args = make_in_list_sql_clause( + self.database_engine, "m.membership", membership_list + ) sql = """ SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering FROM current_state_events AS c @@ -390,12 +398,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): WHERE c.type = 'm.room.member' AND state_key = ? - AND m.membership IN (%s) + AND %s """ % ( - ",".join("?" * len(membership_list)) + clause, ) - txn.execute(sql, (user_id, *membership_list)) + txn.execute(sql, (user_id, *args)) results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] if do_invite: @@ -483,6 +491,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) return result + @defer.inlineCallbacks def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: @@ -492,9 +501,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) + with Measure(self._clock, "get_joined_users_from_state"): + return ( + yield self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry + ) + ) @cachedInlineCallbacks( num_args=2, cache_context=True, iterable=True, max_entries=100000 @@ -567,25 +579,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): missing_member_event_ids.append(event_id) if missing_member_event_ids: - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=missing_member_event_ids, - retcols=("user_id", "display_name", "avatar_url"), - keyvalues={"membership": Membership.JOIN}, - batch_size=500, - desc="_get_joined_users_from_context", - ) - - users_in_room.update( - { - to_ascii(row["user_id"]): ProfileInfo( - avatar_url=to_ascii(row["avatar_url"]), - display_name=to_ascii(row["display_name"]), - ) - for row in rows - } + event_to_memberships = yield self._get_joined_profiles_from_event_ids( + missing_member_event_ids ) + users_in_room.update((row for row in event_to_memberships.values() if row)) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -597,6 +594,47 @@ class RoomMemberWorkerStore(EventsWorkerStore): return users_in_room + @cached(max_entries=10000) + def _get_joined_profile_from_event_id(self, event_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", + inlineCallbacks=True, + ) + def _get_joined_profiles_from_event_ids(self, event_ids): + """For given set of member event_ids check if they point to a join + event and if so return the associated user and profile info. + + Args: + event_ids (Iterable[str]): The member event IDs to lookup + + Returns: + Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + to `user_id` and ProfileInfo (or None if not join event). 
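# The @cached/@cachedList pairing above caches per-event results while letting
# misses be filled by one batched lookup.  The class below is a much simplified,
# synchronous analogue of that pattern (no Deferreds, no eviction); the names
# and the backend callable are invented for illustration.
class BatchedProfileLookup:
    def __init__(self, fetch_many):
        self._fetch_many = fetch_many  # callable: iterable of event ids -> dict
        self._cache = {}

    def get_many(self, event_ids):
        missing = [e for e in event_ids if e not in self._cache]
        if missing:
            fetched = self._fetch_many(missing)
            for event_id in missing:
                # Cache negative results too, so a non-join event does not
                # trigger another fetch next time it is asked for.
                self._cache[event_id] = fetched.get(event_id)
        return {e: self._cache[e] for e in event_ids}


# Hypothetical backend that only knows about one join event.
lookup = BatchedProfileLookup(
    lambda ids: {"$join1": ("@alice:example.org", {"display_name": "Alice"})}
)
print(lookup.get_many(["$join1", "$leave1"]))
# {'$join1': ('@alice:example.org', {'display_name': 'Alice'}), '$leave1': None}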
+ """ + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=event_ids, + retcols=("user_id", "display_name", "avatar_url", "event_id"), + keyvalues={"membership": Membership.JOIN}, + batch_size=500, + desc="_get_membership_from_event_ids", + ) + + return { + row["event_id"]: ( + row["user_id"], + ProfileInfo( + avatar_url=row["avatar_url"], display_name=row["display_name"] + ), + ) + for row in rows + } + @cachedInlineCallbacks(max_entries=10000) def is_host_joined(self, room_id, host): if "%" in host or "_" in host: @@ -669,6 +707,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True + @defer.inlineCallbacks def get_joined_hosts(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: @@ -678,9 +717,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) + with Measure(self._clock, "get_joined_hosts"): + return ( + yield self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry + ) + ) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) # @defer.inlineCallbacks @@ -785,9 +827,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) -class RoomMemberStore(RoomMemberWorkerStore): +class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) + super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) self.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -803,112 +845,6 @@ class RoomMemberStore(RoomMemberWorkerStore): where_clause="forgotten = 1", ) - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ], - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? 
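# get_joined_users_from_state and get_joined_hosts are rewritten above as
# inlineCallbacks, presumably so the Measure block covers the wait for the
# cached lookup's Deferred to resolve rather than just the call that returns
# it.  Below is a minimal synchronous stand-in for Measure (printing instead
# of updating metrics); all names here are illustrative only.
import time
from contextlib import contextmanager


@contextmanager
def measure(name):
    start = time.monotonic()
    try:
        yield
    finally:
        print("%s took %.3fs" % (name, time.monotonic() - start))


def get_joined_users(room_id):
    # Hypothetical expensive lookup standing in for the cached helper.
    time.sleep(0.01)
    return {"@alice:example.org", "@bob:example.org"}


with measure("get_joined_users_from_state"):
    users = get_joined_users("!room:example.org")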
AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - - def forget(self, user_id, room_id): - """Indicate that user_id wishes to discard history for room_id.""" - - def f(txn): - sql = ( - "UPDATE" - " room_memberships" - " SET" - " forgotten = 1" - " WHERE" - " user_id = ?" - " AND" - " room_id = ?" - ) - txn.execute(sql, (user_id, room_id)) - - self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.get_forgotten_rooms_for_user, (user_id,) - ) - - return self.runInteraction("forget_membership", f) - @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( @@ -1043,6 +979,117 @@ class RoomMemberStore(RoomMemberWorkerStore): return row_count +class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RoomMemberStore, self).__init__(db_conn, hs) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. + """ + self._simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ], + ) + + for event in events: + txn.call_after( + self._membership_stream_cache.entity_has_changed, + event.state_key, + event.internal_metadata.stream_ordering, + ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. If the event is an + # outlier it is only current if its an "out of band membership", + # like a remote invite or a rejection of a remote invite. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_out_of_band_membership() + ) + is_mine = self.hs.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + }, + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? 
AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) + + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, (stream_ordering, True, room_id, user_id)) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + + def forget(self, user_id, room_id): + """Indicate that user_id wishes to discard history for room_id.""" + + def f(txn): + sql = ( + "UPDATE" + " room_memberships" + " SET" + " forgotten = 1" + " WHERE" + " user_id = ?" + " AND" + " room_id = ?" + ) + txn.execute(sql, (user_id, room_id)) + + self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + self._invalidate_cache_and_stream( + txn, self.get_forgotten_rooms_for_user, (user_id,) + ) + + return self.runInteraction("forget_membership", f) + + class _JoinedHostsCache(object): """Cache for joined hosts in a room that is optimised to handle updates via state deltas. diff --git a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/schema/delta/56/drop_unused_event_tables.sql new file mode 100644 index 0000000000..9f09922c67 --- /dev/null +++ b/synapse/storage/schema/delta/56/drop_unused_event_tables.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- these tables are never used. +DROP TABLE IF EXISTS room_names; +DROP TABLE IF EXISTS topics; +DROP TABLE IF EXISTS history_visibility; +DROP TABLE IF EXISTS guest_access; diff --git a/synapse/storage/schema/delta/56/public_room_list_idx.sql b/synapse/storage/schema/delta/56/public_room_list_idx.sql new file mode 100644 index 0000000000..7be31ffebb --- /dev/null +++ b/synapse/storage/schema/delta/56/public_room_list_idx.sql @@ -0,0 +1,16 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
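# The hunks above split RoomMemberStore into a worker store, a background
# update store, and a concrete store inheriting from both, presumably so that
# background updates are only registered on the process meant to run them.  A
# toy sketch of that composition (class and attribute names are invented):
class BaseStoreSketch:
    def __init__(self, db_conn, hs):
        self.db_conn = db_conn
        self.hs = hs


class WorkerStoreSketch(BaseStoreSketch):
    # Read-side methods live here.
    def get_rooms_for_user(self, user_id):
        return []


class BackgroundUpdateStoreSketch(BaseStoreSketch):
    # Background-update registration lives here.
    def __init__(self, db_conn, hs):
        super().__init__(db_conn, hs)
        self.registered_updates = ["room_membership_profile_update"]


class FullStoreSketch(WorkerStoreSketch, BackgroundUpdateStoreSketch):
    # Write-side methods go on the concrete store; Python's MRO runs each
    # parent __init__ exactly once via the cooperative super() calls.
    pass


store = FullStoreSketch(db_conn=None, hs=None)
print(store.registered_updates)  # ['room_membership_profile_update']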
+ */ + +CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id); diff --git a/synapse/storage/schema/delta/56/unique_user_filter_index.py b/synapse/storage/schema/delta/56/unique_user_filter_index.py new file mode 100644 index 0000000000..1de8b54961 --- /dev/null +++ b/synapse/storage/schema/delta/56/unique_user_filter_index.py @@ -0,0 +1,52 @@ +import logging + +from synapse.storage.engines import PostgresEngine + +logger = logging.getLogger(__name__) + + +""" +This migration updates the user_filters table as follows: + + - drops any (user_id, filter_id) duplicates + - makes the columns NON-NULLable + - turns the index into a UNIQUE index +""" + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + select_clause = """ + SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json + FROM user_filters + """ + else: + select_clause = """ + SELECT * FROM user_filters GROUP BY user_id, filter_id + """ + sql = """ + DROP TABLE IF EXISTS user_filters_migration; + DROP INDEX IF EXISTS user_filters_unique; + CREATE TABLE user_filters_migration ( + user_id TEXT NOT NULL, + filter_id BIGINT NOT NULL, + filter_json BYTEA NOT NULL + ); + INSERT INTO user_filters_migration (user_id, filter_id, filter_json) + %s; + CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration + (user_id, filter_id); + DROP TABLE user_filters; + ALTER TABLE user_filters_migration RENAME TO user_filters; + """ % ( + select_clause, + ) + + if isinstance(database_engine, PostgresEngine): + cur.execute(sql) + else: + cur.executescript(sql) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index df87ab6a6d..7695bf09fc 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -24,6 +24,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.storage._base import make_in_list_sql_clause from synapse.storage.engines import PostgresEngine, Sqlite3Engine from .background_updates import BackgroundUpdateStore @@ -36,7 +37,7 @@ SearchEntry = namedtuple( ) -class SearchStore(BackgroundUpdateStore): +class SearchBackgroundUpdateStore(BackgroundUpdateStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -44,7 +45,7 @@ class SearchStore(BackgroundUpdateStore): EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) if not hs.config.enable_search: return @@ -289,29 +290,6 @@ class SearchStore(BackgroundUpdateStore): return num_rows - def store_event_search_txn(self, txn, event, key, value): - """Add event to the search table - - Args: - txn (cursor): - event (EventBase): - key (str): - value (str): - """ - self.store_search_entries_txn( - txn, - ( - SearchEntry( - key=key, - value=value, - event_id=event.event_id, - room_id=event.room_id, - stream_ordering=event.internal_metadata.stream_ordering, - origin_server_ts=event.origin_server_ts, - ), - ), - ) - def store_search_entries_txn(self, txn, entries): """Add entries to the search table @@ -358,6 +336,34 @@ class SearchStore(BackgroundUpdateStore): # This should be unreachable. 
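# The unique_user_filter_index migration above deduplicates user_filters and
# adds a UNIQUE index by rebuilding the table.  A runnable sqlite3 rendering of
# its SQLite branch, with BLOB standing in for BYTEA and made-up sample rows:
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute(
    "CREATE TABLE user_filters (user_id TEXT, filter_id BIGINT, filter_json BLOB)"
)
cur.executemany(
    "INSERT INTO user_filters VALUES (?, ?, ?)",
    [("@a:example.org", 1, b"{}"), ("@a:example.org", 1, b"{}"), ("@b:example.org", 2, b"{}")],
)

# Rebuild via a temporary copy: GROUP BY keeps one arbitrary row per
# (user_id, filter_id), then the UNIQUE index is created and the tables swapped.
cur.executescript(
    """
    CREATE TABLE user_filters_migration (
        user_id TEXT NOT NULL,
        filter_id BIGINT NOT NULL,
        filter_json BLOB NOT NULL
    );
    INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
        SELECT * FROM user_filters GROUP BY user_id, filter_id;
    CREATE UNIQUE INDEX user_filters_unique
        ON user_filters_migration (user_id, filter_id);
    DROP TABLE user_filters;
    ALTER TABLE user_filters_migration RENAME TO user_filters;
    """
)
print(cur.execute("SELECT COUNT(*) FROM user_filters").fetchone())  # (2,)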
raise Exception("Unrecognized database engine") + +class SearchStore(SearchBackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(SearchStore, self).__init__(db_conn, hs) + + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table + + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): + """ + self.store_search_entries_txn( + txn, + ( + SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ), + ), + ) + @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. @@ -380,8 +386,10 @@ class SearchStore(BackgroundUpdateStore): # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. if len(room_ids) < 500: - clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) - args.extend(room_ids) + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] local_clauses = [] for key in keys: @@ -487,8 +495,10 @@ class SearchStore(BackgroundUpdateStore): # Make sure we don't explode because the person is in too many rooms. # We filter the results below regardless. if len(room_ids) < 500: - clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)) - args.extend(room_ids) + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + clauses = [clause] local_clauses = [] for key in keys: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 1980a87108..a941a5ae3f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -353,8 +353,158 @@ class StateFilter(object): return member_filter, non_member_filter +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. 
wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. 
Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + # this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): +class StateGroupWorkerStore( + EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore +): """The parts of StateGroupStore that can be called from workers. """ @@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? 
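# On SQLite the state-group code above avoids WITH RECURSIVE and instead walks
# the prev_state_group chain one row at a time.  A standalone sqlite3 sketch of
# that fallback for the hop count (table contents are made up):
import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute(
    "CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT)"
)
# A short delta chain: 4 -> 3 -> 2 -> 1 (state group 1 has no parent).
cur.executemany("INSERT INTO state_group_edges VALUES (?, ?)", [(4, 3), (3, 2), (2, 1)])


def count_state_group_hops(cur, state_group):
    count = 0
    next_group = state_group
    while next_group is not None:
        row = cur.execute(
            "SELECT prev_state_group FROM state_group_edges WHERE state_group = ?",
            (next_group,),
        ).fetchone()
        next_group = row[0] if row else None
        if next_group is not None:
            count += 1
    return count


print(count_state_group_hops(cur, 4))  # 3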
" + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state @@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return self.runInteraction("store_state_group", _store_state_group_txn) - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - -class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. 
- """ +class StateBackgroundUpdateStore( + StateGroupBackgroundUpdateStore, BackgroundUpdateStore +): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" @@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) + super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, @@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): columns=["state_group"], ) - def _store_event_state_mappings_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) - ], - ) - - for event_id, state_group_id in iteritems(state_groups): - txn.call_after( - self._get_state_group_for_event.prefill, (event_id,), state_group_id - ) - @defer.inlineCallbacks def _background_deduplicate_state(self, progress, batch_size): """This background update will slowly deduplicate state by reencoding @@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) return 1 + + +class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. 
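# The event-to-state-group mapping here skips outliers and points rejected
# events at their predecessor's state group before doing one bulk insert.  A
# plain-Python sketch of that selection logic, with dicts standing in for
# EventBase/EventContext:
def map_events_to_state_groups(events_and_contexts):
    state_groups = {}
    for event, context in events_and_contexts:
        if event.get("outlier"):
            continue
        if context.get("rejected"):
            state_groups[event["event_id"]] = context["prev_group"]
            continue
        state_groups[event["event_id"]] = context["state_group"]
    return state_groups


print(
    map_events_to_state_groups(
        [
            ({"event_id": "$a"}, {"state_group": 10}),
            ({"event_id": "$b", "outlier": True}, {}),
            ({"event_id": "$c"}, {"rejected": True, "prev_group": 10}),
        ]
    )
)
# {'$a': 10, '$c': 10}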
+ if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + {"state_group": state_group_id, "event_id": event_id} + for event_id, state_group_id in iteritems(state_groups) + ], + ) + + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self._get_state_group_for_event.prefill, (event_id,), state_group_id + ) diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 5fdb442104..28f33ec18f 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) class StateDeltasStore(SQLBaseStore): - def get_current_state_deltas(self, prev_stream_id): + def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int): """Fetch a list of room state changes since the given stream id Each entry in the result contains the following fields: @@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore): Args: prev_stream_id (int): point to get changes since (exclusive) + max_stream_id (int): the point that we know has been correctly persisted + - ie, an upper limit to return changes from. Returns: - Deferred[list[dict]]: results + Deferred[tuple[int, list[dict]]: A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. """ prev_stream_id = int(prev_stream_id) + + # check we're not going backwards + assert prev_stream_id <= max_stream_id + if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id ): - return [] + # if the CSDs haven't changed between prev_stream_id and now, we + # know for certain that they haven't changed between prev_stream_id and + # max_stream_id. + return max_stream_id, [] def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than @@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore): sql = """ SELECT stream_id, count(*) FROM current_state_delta_stream - WHERE stream_id > ? + WHERE stream_id > ? AND stream_id <= ? GROUP BY stream_id ORDER BY stream_id ASC LIMIT 100 """ - txn.execute(sql, (prev_stream_id,)) + txn.execute(sql, (prev_stream_id, max_stream_id)) total = 0 - max_stream_id = prev_stream_id - for max_stream_id, count in txn: + + for stream_id, count in txn: total += count if total > 100: # We arbitarily limit to 100 entries to ensure we don't # select toooo many. + logger.debug( + "Clipping current_state_delta_stream rows to stream_id %i", + stream_id, + ) + clipped_stream_id = stream_id break + else: + # if there's no problem, we may as well go right up to the max_stream_id + clipped_stream_id = max_stream_id # Now actually get the deltas sql = """ @@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? 
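# The clipping loop above caps how many current_state_delta_stream rows one
# call returns: it walks the per-stream_id counts in ascending order and stops
# at the first stream id where the running total passes 100, otherwise it
# falls back to max_stream_id.  A standalone sketch of that decision (the
# function name and sample data are invented):
def clip_stream_id(rows, max_stream_id, limit=100):
    """rows: (stream_id, count) pairs in ascending stream_id order."""
    total = 0
    for stream_id, count in rows:
        total += count
        if total > limit:
            # Everything up to and including this stream id is still fetched,
            # so the cap is approximate, as in the store method.
            return stream_id
    return max_stream_id


print(clip_stream_id([(5, 40), (6, 50), (7, 30)], max_stream_id=20))  # 7
print(clip_stream_id([(5, 40), (6, 50)], max_stream_id=20))           # 20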
ORDER BY stream_id ASC """ - txn.execute(sql, (prev_stream_id, max_stream_id)) - return self.cursor_to_dict(txn) + txn.execute(sql, (prev_stream_id, clipped_stream_id)) + return clipped_stream_id, self.cursor_to_dict(txn) return self.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 09190d684e..7c224cd3d9 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -332,6 +332,9 @@ class StatsStore(StateDeltasStore): def _bulk_update_stats_delta_txn(txn): for stats_type, stats_updates in updates.items(): for stats_id, fields in stats_updates.items(): + logger.info( + "Updating %s stats for %s: %s", stats_type, stats_id, fields + ) self._update_stats_delta_txn( txn, ts=ts, diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index b5188d9bee..1b1e4751b9 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -32,14 +32,14 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): +class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) self.server_name = hs.hostname @@ -452,55 +452,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) - def remove_from_user_dir(self, user_id): - def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( - txn, table="user_directory", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, table="user_directory_search", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} - ) - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"user_id": user_id}, - ) - self._simple_delete_txn( - txn, - table="users_who_share_private_rooms", - keyvalues={"other_user_id": user_id}, - ) - txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - - return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) - - @defer.inlineCallbacks - def get_users_in_dir_due_to_room(self, room_id): - """Get all user_ids that are in the room directory because they're - in the given room_id - """ - user_ids_share_pub = yield self._simple_select_onecol( - table="users_in_public_rooms", - keyvalues={"room_id": room_id}, - retcol="user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids_share_priv = yield self._simple_select_onecol( - table="users_who_share_private_rooms", - keyvalues={"room_id": room_id}, - retcol="other_user_id", - desc="get_users_in_dir_due_to_room", - ) - - user_ids = set(user_ids_share_pub) - user_ids.update(user_ids_share_priv) - - return user_ids - def add_users_who_share_private_room(self, room_id, user_id_tuples): """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. 
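Because get_current_state_deltas above now returns the stream id the results go up to as well as the rows, a caller can advance its position even when the row list comes back empty. A minimal sketch of such a consumer loop, with a toy stand-in for the store method (the function and variable names here are illustrative, not Synapse's):

def process_deltas(get_current_state_deltas, start_pos, max_pos):
    # Keep asking for deltas until everything up to max_pos has been consumed;
    # the returned clipped stream id moves the position forward even when no
    # rows were returned.
    pos = start_pos
    while pos < max_pos:
        pos, rows = get_current_state_deltas(pos, max_pos)
        for row in rows:
            pass  # handle each current_state_delta_stream row here
    return pos


# Toy store method that reports everything as already processed.
print(process_deltas(lambda prev, maximum: (maximum, []), start_pos=3, max_pos=10))  # 10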
@@ -551,6 +502,98 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_public_rooms") + txn.execute("DELETE FROM users_who_share_private_rooms") + txn.call_after(self.get_user_in_directory.invalidate_all) + + return self.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self._simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("display_name", "avatar_url"), + allow_none=True, + desc="get_user_in_directory", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self._simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + +class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): + + # How many records do we calculate before sending it to + # add_users_who_share_private_rooms? + SHARE_PRIVATE_WORKING_SET = 500 + + def __init__(self, db_conn, hs): + super(UserDirectoryStore, self).__init__(db_conn, hs) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self._simple_delete_txn( + txn, table="user_directory", keyvalues={"user_id": user_id} + ) + self._simple_delete_txn( + txn, table="user_directory_search", keyvalues={"user_id": user_id} + ) + self._simple_delete_txn( + txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} + ) + self._simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="users_who_share_private_rooms", + keyvalues={"other_user_id": user_id}, + ) + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + + @defer.inlineCallbacks + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory because they're + in the given room_id + """ + user_ids_share_pub = yield self._simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_share_priv = yield self._simple_select_onecol( + table="users_who_share_private_rooms", + keyvalues={"room_id": room_id}, + retcol="other_user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids = set(user_ids_share_pub) + user_ids.update(user_ids_share_priv) + + return user_ids + def remove_user_who_share_room(self, user_id, room_id): """ Deletes entries in the users_who_share_*_rooms table. 
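# remove_from_user_dir and delete_all_from_user_dir above queue their cache
# invalidations with txn.call_after, so they run after the transaction
# completes rather than mid-transaction.  A tiny synchronous analogue of that
# ordering (these classes are invented for illustration; they are not
# Synapse's LoggingTransaction or cache descriptors):
class FakeCache:
    def __init__(self):
        self.data = {"@alice:example.org": {"display_name": "Alice"}}

    def invalidate(self, key):
        self.data.pop(key, None)


class FakeTxn:
    def __init__(self):
        self._after_callbacks = []

    def call_after(self, fn, *args):
        # Queue the callback; it only runs if/when the transaction commits.
        self._after_callbacks.append((fn, args))

    def commit(self):
        for fn, args in self._after_callbacks:
            fn(*args)


cache = FakeCache()
txn = FakeTxn()
# ... the DELETEs would be issued against the database here ...
txn.call_after(cache.invalidate, "@alice:example.org")
txn.commit()
print(cache.data)  # {}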
The first @@ -637,31 +680,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): return [room_id for room_id, in rows] - def delete_all_from_user_dir(self): - """Delete the entire user directory - """ - - def _delete_all_from_user_dir_txn(txn): - txn.execute("DELETE FROM user_directory") - txn.execute("DELETE FROM user_directory_search") - txn.execute("DELETE FROM users_in_public_rooms") - txn.execute("DELETE FROM users_who_share_private_rooms") - txn.call_after(self.get_user_in_directory.invalidate_all) - - return self.runInteraction( - "delete_all_from_user_dir", _delete_all_from_user_dir_txn - ) - - @cached() - def get_user_in_directory(self, user_id): - return self._simple_select_one( - table="user_directory", - keyvalues={"user_id": user_id}, - retcols=("display_name", "avatar_url"), - allow_none=True, - desc="get_user_in_directory", - ) - def get_user_directory_stream_pos(self): return self._simple_select_one_onecol( table="user_directory_stream_pos", @@ -670,14 +688,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore): desc="get_user_directory_stream_pos", ) - def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( - table="user_directory_stream_pos", - keyvalues={}, - updatevalues={"stream_id": stream_id}, - desc="update_user_directory_stream_pos", - ) - @defer.inlineCallbacks def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py index 05cabc2282..aa4f0da5f0 100644 --- a/synapse/storage/user_erasure_store.py +++ b/synapse/storage/user_erasure_store.py @@ -56,15 +56,15 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - def _get_erased_users(txn): - txn.execute( - "SELECT user_id FROM erased_users WHERE user_id IN (%s)" - % (",".join("?" * len(user_ids))), - user_ids, - ) - return set(r[0] for r in txn) - - erased_users = yield self.runInteraction("are_users_erased", _get_erased_users) + rows = yield self._simple_select_many_batch( + table="erased_users", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="are_users_erased", + ) + erased_users = set(row["user_id"] for row in rows) + res = dict((u, u in erased_users) for u in user_ids) return res |
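The user_erasure_store change above swaps a hand-rolled IN query for the generic _simple_select_many_batch helper, which also batches large id lists. The end result is the same user-to-flag mapping; below is a standalone sqlite3 sketch of that mapping (without the Synapse helper, so the plain IN form is used; table contents are made up):

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE erased_users (user_id TEXT PRIMARY KEY)")
cur.execute("INSERT INTO erased_users VALUES ('@bob:example.org')")


def are_users_erased(cur, user_ids):
    # Deduplicate, look the ids up in one parameterised query, then report a
    # True/False flag for every requested user.
    user_ids = tuple(set(user_ids))
    placeholders = ",".join("?" for _ in user_ids)
    cur.execute(
        "SELECT user_id FROM erased_users WHERE user_id IN (%s)" % placeholders,
        user_ids,
    )
    erased = {row[0] for row in cur}
    return {u: u in erased for u in user_ids}


print(are_users_erased(cur, ["@alice:example.org", "@bob:example.org"]))
# e.g. {'@alice:example.org': False, '@bob:example.org': True}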