Diffstat (limited to 'synapse/storage'): 37 files changed, 1155 insertions, 628 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6116191b16..0ba3a025cf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from typing import ( overload, ) +import attr from prometheus_client import Histogram from typing_extensions import Literal @@ -90,13 +91,17 @@ def make_pool( return adbapi.ConnectionPool( db_config.config["name"], cp_reactor=reactor, - cp_openfun=engine.on_new_connection, + cp_openfun=lambda conn: engine.on_new_connection( + LoggingDatabaseConnection(conn, engine, "on_new_connection") + ), **db_config.config.get("args", {}) ) def make_conn( - db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, + default_txn_name: str, ) -> Connection: """Make a new connection to the database and return it. @@ -109,11 +114,60 @@ def make_conn( for k, v in db_config.config.get("args", {}).items() if not k.startswith("cp_") } - db_conn = engine.module.connect(**db_params) + native_db_conn = engine.module.connect(**db_params) + db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) + engine.on_new_connection(db_conn) return db_conn +@attr.s(slots=True) +class LoggingDatabaseConnection: + """A wrapper around a database connection that returns `LoggingTransaction` + as its cursor class. + + This is mainly used on startup to ensure that queries get logged correctly + """ + + conn = attr.ib(type=Connection) + engine = attr.ib(type=BaseDatabaseEngine) + default_txn_name = attr.ib(type=str) + + def cursor( + self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + ) -> "LoggingTransaction": + if not txn_name: + txn_name = self.default_txn_name + + return LoggingTransaction( + self.conn.cursor(), + name=txn_name, + database_engine=self.engine, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, + ) + + def close(self) -> None: + self.conn.close() + + def commit(self) -> None: + self.conn.commit() + + def rollback(self, *args, **kwargs) -> None: + self.conn.rollback(*args, **kwargs) + + def __enter__(self) -> "Connection": + self.conn.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + return self.conn.__exit__(exc_type, exc_value, traceback) + + # Proxy through any unknown lookups to the DB conn class. + def __getattr__(self, name): + return getattr(self.conn, name) + + # The type of entry which goes on our after_callbacks and exception_callbacks lists. 
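The new `LoggingDatabaseConnection` is a thin proxy: it overrides `cursor()` so every cursor comes back as a named `LoggingTransaction`, and forwards everything else to the wrapped DB-API connection via `__getattr__`. A minimal sketch of that pattern using sqlite3, with invented names (`NamedCursor`, `LoggingConnection`) purely for illustration, not Synapse's actual classes:

```python
import sqlite3
import time


class NamedCursor:
    """Toy stand-in for LoggingTransaction: tags and times each query."""

    def __init__(self, cursor, name):
        self._cursor = cursor
        self.name = name

    def execute(self, sql, args=()):
        start = time.monotonic()
        try:
            return self._cursor.execute(sql, args)
        finally:
            print("[%s] %r took %.6fs" % (self.name, sql, time.monotonic() - start))

    def __getattr__(self, attr):
        # fetchone, fetchall, close, ... fall through to the real cursor.
        return getattr(self._cursor, attr)


class LoggingConnection:
    """Wraps a DB-API connection so cursors come back pre-named for logging."""

    def __init__(self, conn, default_txn_name):
        self.conn = conn
        self.default_txn_name = default_txn_name

    def cursor(self, *, txn_name=None):
        return NamedCursor(self.conn.cursor(), txn_name or self.default_txn_name)

    def __getattr__(self, name):
        # commit, rollback, close, ... proxy through to the wrapped connection,
        # mirroring LoggingDatabaseConnection.__getattr__.
        return getattr(self.conn, name)


conn = LoggingConnection(sqlite3.connect(":memory:"), "startup")
txn = conn.cursor(txn_name="create_table")
txn.execute("CREATE TABLE users (name TEXT)")
conn.commit()
```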
# # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so @@ -247,6 +301,12 @@ class LoggingTransaction: def close(self) -> None: self.txn.close() + def __enter__(self) -> "LoggingTransaction": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + class PerformanceCounters: def __init__(self): @@ -395,7 +455,7 @@ class DatabasePool: def new_transaction( self, - conn: Connection, + conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], @@ -436,12 +496,10 @@ class DatabasePool: i = 0 N = 5 while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.engine, - after_callbacks, - exception_callbacks, + cursor = conn.cursor( + txn_name=name, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, ) try: r = func(cursor, *args, **kwargs) @@ -638,7 +696,10 @@ class DatabasePool: if db_autocommit: self.engine.attempt_to_set_autocommit(conn, True) - return func(conn, *args, **kwargs) + db_conn = LoggingDatabaseConnection( + conn, self.engine, "runWithConnection" + ) + return func(db_conn, *args, **kwargs) finally: if db_autocommit: self.engine.attempt_to_set_autocommit(conn, False) @@ -1678,7 +1739,7 @@ class DatabasePool: def get_cache_dict( self, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, table: str, entity_column: str, stream_column: str, @@ -1699,9 +1760,7 @@ class DatabasePool: "limit": limit, } - sql = self.engine.convert_param_style(sql) - - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="get_cache_dict") txn.execute(sql, (int(max_value),)) cache = {row[0]: int(row[1]) for row in txn} diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index aa5d490624..0c24325011 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -46,7 +46,7 @@ class Databases: db_name = database_config.name engine = create_engine(database_config.config) - with make_conn(database_config, engine) as db_conn: + with make_conn(database_config, engine, "startup") as db_conn: logger.info("[database config %r]: Checking database server", db_name) engine.check_database(db_conn) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0cb12f4c61..9b16f45f3e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar import logging -import time from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState @@ -268,9 +266,6 @@ class DataStore( self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -289,7 +284,6 @@ class DataStore( " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" 
) - sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) @@ -301,192 +295,6 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - async def count_daily_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return await self.db_pool.runInteraction( - "count_daily_users", self._count_users, yesterday - ) - - async def count_monthly_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return await self.db_pool.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() - return count - - async def count_r30_users(self) -> Dict[str, int]: - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns: - A mapping of counts globally as well as broken out by platform. - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? 
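The CASE expression in `count_r30_users` (removed here and re-added on the metrics store later in this diff) buckets each user agent into a coarse platform, taking the first marker that matches. The same logic in plain Python, as a sketch that mirrors the SQL rather than code from this change:

```python
def classify_platform(user_agent: str) -> str:
    """Bucket a user-agent string into a coarse platform label, most specific first."""
    for marker, platform in (
        ("Android", "android"),
        ("iOS", "ios"),
        ("Electron", "electron"),
        ("Mozilla", "web"),
        ("Gecko", "web"),
    ):
        if marker in user_agent:
            return platform
    return "unknownown" if False else "unknown"


print(classify_platform("Mozilla/5.0 (X11; Linux x86_64) Gecko/20100101"))  # web
print(classify_platform("Element/1.0 (Android 11)"))  # android
```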
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = txn.fetchone() - results["all"] = count - - return results - - return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - async def generate_user_daily_visits(self) -> None: - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - await self.db_pool.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index ef81d73573..49ee23470d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,6 +18,7 @@ import abc import logging from typing import Dict, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator @@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext ) -> bool: ignored_account_data = await self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", + AccountDataTypes.IGNORED_USER_LIST, ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False - return ignored_user_id in ignored_account_data.get("ignored_users", {}) + try: + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + except TypeError: + # The type of the ignored_users field is invalid. + return False class AccountDataStore(AccountDataWorkerStore): diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index f211ddbaf8..4bb2b9c28c 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.databases.main.events import encode_json from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.frozenutils import frozendict_json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. 
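The new try/except around the membership test guards against accounts whose `m.ignored_user_list` content has an `ignored_users` field that is not a mapping. A small illustration of why the guard is needed (the sample data is made up):

```python
def is_ignored(ignored_user_id, ignored_account_data):
    """Mirrors the guarded check added to the ignored-user lookup."""
    if not ignored_account_data:
        return False
    try:
        return ignored_user_id in ignored_account_data.get("ignored_users", {})
    except TypeError:
        # The field is client-supplied JSON, so it may not support `in` at all.
        return False


print(is_ignored("@spammer:example.org", {"ignored_users": {"@spammer:example.org": {}}}))  # True
print(is_ignored("@spammer:example.org", {"ignored_users": 42}))  # False, no TypeError escapes
```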
- pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict(event.room_version, event.get_dict()) ) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 239c7a949c..a25a888443 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -351,7 +351,63 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return updated -class ClientIpStore(ClientIpBackgroundUpdateStore): +class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + if hs.config.run_background_tasks and self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.db_pool.updates.has_completed_background_update( + "devices_last_seen" + ): + # Only start pruning if we have finished populating the devices + # last seen info. + return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? + ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ + + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn + ) + + +class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = Cache( @@ -360,8 +416,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): super().__init__(database, db_conn, hs) - self.user_ips_max_age = hs.config.user_ips_max_age - # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update = {} @@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): @@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } for (access_token, ip), (user_agent, last_seen) in results.items() ] - - @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ - - if self.user_ips_max_age is None: - # Nothing to do - return - - if not await self.db_pool.updates.has_completed_background_update( - "devices_last_seen" - ): - # Only start pruning if we have finished populating the devices - # last seen info. 
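The pruning DELETE above bounds how much work happens in one pass: the sub-select finds the highest `last_seen` among the first N qualifying rows, and everything up to that value is removed. A runnable sketch of the same pattern against sqlite (the 20-row table and batch size of 5 are invented for the demo; the real query uses LIMIT 5000):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_ips (user_id TEXT, last_seen INTEGER)")
conn.executemany(
    "INSERT INTO user_ips VALUES (?, ?)",
    [("@user%d:example.org" % i, i) for i in range(20)],
)

# Delete rows older than the cutoff, but at most `batch` rows per call: the
# sub-select picks the largest last_seen within the first `batch` qualifying
# rows, and COALESCE keeps it returning a row even when nothing matches.
PRUNE_SQL = """
    DELETE FROM user_ips
    WHERE last_seen <= (
        SELECT COALESCE(MAX(last_seen), -1)
        FROM (
            SELECT last_seen FROM user_ips
            WHERE last_seen <= ?
            ORDER BY last_seen ASC
            LIMIT ?
        ) AS u
    )
"""

cutoff, batch = 15, 5
conn.execute(PRUNE_SQL, (cutoff, batch))
print(conn.execute("SELECT COUNT(*) FROM user_ips").fetchone()[0])  # 15 rows left
```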
- return - - # We do a slightly funky SQL delete to ensure we don't try and delete - # too much at once (as the table may be very large from before we - # started pruning). - # - # This works by finding the max last_seen that is less than the given - # time, but has no more than N rows before it, deleting all rows with - # a lesser last_seen time. (We COALESCE so that the sub-SELECT always - # returns exactly one row). - sql = """ - DELETE FROM user_ips - WHERE last_seen <= ( - SELECT COALESCE(MAX(last_seen), -1) - FROM ( - SELECT last_seen FROM user_ips - WHERE last_seen <= ? - ORDER BY last_seen ASC - LIMIT 5000 - ) AS u - ) - """ - - timestamp = self.clock.time_msec() - self.user_ips_max_age - - def _prune_old_user_ips_txn(txn): - txn.execute(sql, (timestamp,)) - - await self.db_pool.runInteraction( - "_prune_old_user_ips", _prune_old_user_ips_txn - ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fdf394c612..2d0a6408b5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key -from synapse.util import json_encoder +from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -698,6 +698,80 @@ class DeviceWorkerStore(SQLBaseStore): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + # FIXME: make sure device ID still exists in devices table + row = await self.db_pool.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return ( + (row["device_id"], json_decoder.decode(row["device_data"])) if row else None + ) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + self.db_pool.simple_upsert_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + values={"device_id": device_id, "device_data": device_data}, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: JsonDict + ) -> Optional[str]: + """Store a dehydrated device for a user. 
+ + Args: + user_id: the user that we are storing the device for + device_id: the ID of the dehydrated device + device_data: the dehydrated device information + Returns: + device id of the user's previous dehydrated device, if any + """ + return await self.db_pool.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, + device_id, + json_encoder.encode(device_data), + ) + + async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: + """Remove a dehydrated device. + + Args: + user_id: the user that the dehydrated device belongs to + device_id: the ID of the dehydrated device + """ + count = await self.db_pool.simple_delete( + "dehydrated_devices", + {"user_id": user_id, "device_id": device_id}, + desc="remove_dehydrated_device", + ) + return count >= 1 + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -837,7 +911,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: str + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -955,7 +1029,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def update_remote_device_list_cache_entry( - self, user_id: str, device_id: str, content: JsonDict, stream_id: int + self, user_id: str, device_id: str, content: JsonDict, stream_id: str ) -> None: """Updates a single device in the cache of a remote user's devicelist. @@ -983,7 +1057,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, content: JsonDict, - stream_id: int, + stream_id: str, ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 22e1ed15d0..359dc6e968 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def set_e2e_fallback_keys( + self, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. 
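`_store_dehydrated_device_txn` reads the existing device ID and then upserts the replacement inside a single transaction, so the caller learns which dehydrated device (if any) it displaced. A sketch of that read-then-upsert shape with a simplified schema; the real table and the `simple_*` helpers differ:

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE dehydrated_devices (user_id TEXT PRIMARY KEY, device_id TEXT, device_data TEXT)"
)


def store_dehydrated_device(conn, user_id, device_id, device_data):
    """Upsert the user's single dehydrated device, returning the device ID it replaced."""
    with conn:  # one transaction: read the old row, then overwrite it
        row = conn.execute(
            "SELECT device_id FROM dehydrated_devices WHERE user_id = ?", (user_id,)
        ).fetchone()
        conn.execute(
            "INSERT INTO dehydrated_devices (user_id, device_id, device_data) VALUES (?, ?, ?) "
            "ON CONFLICT (user_id) DO UPDATE SET device_id = excluded.device_id, "
            "device_data = excluded.device_data",
            (user_id, device_id, json.dumps(device_data)),
        )
    return row[0] if row else None


print(store_dehydrated_device(conn, "@alice:hs", "DEV1", {"algorithm": "m.dehydration.v1"}))  # None
print(store_dehydrated_device(conn, "@alice:hs", "DEV2", {"algorithm": "m.dehydration.v1"}))  # DEV1
```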
+ """ + # fallback_keys will usually only have one item in it, so using a for + # loop (as opposed to calling simple_upsert_many_txn) won't be too bad + # FIXME: make sure that only one key per algorithm is uploaded + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + await self.db_pool.simple_upsert( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + desc="set_e2e_fallback_key", + ) + + @cached(max_entries=10000) + async def get_e2e_unused_fallback_key_types( + self, user_id: str, device_id: str + ) -> List[str]: + """Returns the fallback key types that have an unused key. + + Args: + user_id: the user whose keys are being queried + device_id: the device whose keys are being queried + + Returns: + a list of key types + """ + return await self.db_pool.simple_select_onecol( + "e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id, "used": False}, + retcol="algorithm", + desc="get_e2e_unused_fallback_key_types", + ) + async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Optional[dict]: @@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): " WHERE user_id = ? AND device_id = ? AND algorithm = ?" " LIMIT 1" ) + fallback_sql = ( + "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) result = {} delete = [] + used_fallbacks = [] for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: + otk_row = txn.fetchone() + if otk_row is not None: + key_id, key_json = otk_row device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) + else: + # no one-time key available, so see if there's a fallback + # key + txn.execute(fallback_sql, (user_id, device_id, algorithm)) + fallback_row = txn.fetchone() + if fallback_row is not None: + key_id, key_json, used = fallback_row + device_result[algorithm + ":" + key_id] = key_json + if not used: + used_fallbacks.append( + (user_id, device_id, algorithm, key_id) + ) + + # drop any one-time keys that were claimed sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" 
@@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + # mark fallback keys as used + for user_id, device_id, algorithm, key_id in used_fallbacks: + self.db_pool.simple_update_txn( + txn, + "e2e_fallback_keys_json", + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + {"used": True}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) + return result return await self.db_pool.runInteraction( @@ -754,6 +844,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 62f1738732..80f3b4d740 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union import attr from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -74,11 +74,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): self.stream_ordering_month_ago = None self.stream_ordering_day_ago = None - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) + cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) cur.close() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 18def01f50..b4abd961b9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -52,16 +52,6 @@ event_counter = Counter( ) -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. 
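The removed `encode_json` helper existed only to decode a possible bytes result; `frozendict_json_encoder.encode` already returns `str`, so callers now use it directly. Roughly, such an encoder is a `json.JSONEncoder` whose `default` hook turns immutable mappings back into plain dicts. This is a hedged sketch using `MappingProxyType` as a stand-in for frozendict, not Synapse's actual implementation:

```python
import json
from types import MappingProxyType


def _handle_immutable(obj):
    # Turn an immutable mapping back into a plain dict so json can serialise it.
    if isinstance(obj, MappingProxyType):
        return dict(obj)
    raise TypeError("Object of type %s is not JSON serializable" % type(obj).__name__)


frozen_encoder = json.JSONEncoder(default=_handle_immutable)

event_dict = MappingProxyType({"type": "m.room.message", "content": {"body": "hi"}})
print(frozen_encoder.encode(event_dict))  # always a str, so no bytes-decoding wrapper needed
```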
- """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -341,6 +331,10 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # stream orderings should have been assigned by now + assert min_stream_order + assert max_stream_order + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -743,7 +737,9 @@ class PersistEventsStore: logger.exception("") raise - metadata_json = encode_json(event.internal_metadata.get_dict()) + metadata_json = frozendict_json_encoder.encode( + event.internal_metadata.get_dict() + ) sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -797,10 +793,10 @@ class PersistEventsStore: { "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": encode_json( + "internal_metadata": frozendict_json_encoder.encode( event.internal_metadata.get_dict() ), - "json": encode_json(event_dict(event)), + "json": frozendict_json_encoder.encode(event_dict(event)), "format_version": event.format_version, } for event, _ in events_and_contexts diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f95679ebc4..b7ed8ca6ab 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -74,6 +74,13 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): + # Whether to use dedicated DB threads for event fetching. This is only used + # if there are multiple DB threads available. When used will lock the DB + # thread for periods of time (so unit tests want to disable this when they + # run DB transactions on the main thread). See EVENT_QUEUE_* for more + # options controlling this. 
+ USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -522,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore): if not event_list: single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): self._event_fetch_ongoing -= 1 return else: @@ -712,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) + original_ev.internal_metadata.stream_ordering = row["stream_ordering"] event_map[event_id] = original_ev @@ -779,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore): * event_id (str) + * stream_ordering (int): stream ordering for this event + * json (str): json-encoded event structure * internal_metadata (str): json-encoded internal metadata dict @@ -811,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore): sql = """\ SELECT e.event_id, - e.internal_metadata, - e.json, - e.format_version, + e.stream_ordering, + ej.internal_metadata, + ej.json, + ej.format_version, r.room_version, rej.reason - FROM event_json as e - LEFT JOIN rooms r USING (room_id) + FROM events AS e + JOIN event_json AS ej USING (event_id) + LEFT JOIN rooms r ON r.room_id = e.room_id LEFT JOIN rejections as rej USING (event_id) WHERE """ @@ -831,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore): event_id = row[0] event_dict[event_id] = { "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "room_version_id": row[4], - "rejected_reason": row[5], + "stream_ordering": row[1], + "internal_metadata": row[2], + "json": row[3], + "format_version": row[4], + "room_version_id": row[5], + "rejected_reason": row[6], "redactions": [], } diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 92099f95ce..0acf0617ca 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,15 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import calendar +import logging +import time +from typing import Dict from synapse.metrics import GaugeBucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +logger = logging.getLogger(__name__) + # Collect metrics on the number of forward extremities that exist. 
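Several stores in this diff swap a hand-rolled `def start_x(): return run_as_background_process(...)` closure for the `@wrap_as_background_process(desc)` decorator, and gate the scheduling behind `hs.config.run_background_tasks`. A much-simplified stand-in for what such a decorator does, using asyncio in place of Twisted and omitting the logging-context and metrics handling, purely to show the shape:

```python
import asyncio
import functools


def run_as_background_process(desc, func, *args, **kwargs):
    # Stand-in runner: the real helper also sets up a logging context and metrics.
    print("starting background process %r" % desc)
    return asyncio.ensure_future(func(*args, **kwargs))


def wrap_as_background_process(desc):
    """Decorator form: callers such as looping_call can invoke the method directly."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return run_as_background_process(desc, func, *args, **kwargs)

        return wrapper

    return decorator


@wrap_as_background_process("read_forward_extremities")
async def read_forward_extremities():
    await asyncio.sleep(0)  # pretend to query the database
    print("read extremities")


async def main():
    await read_forward_extremities()  # a looping_call would fire this periodically


asyncio.run(main())
```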
_extremities_collecter = GaugeBucketCollector( "synapse_forward_extremities", @@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) + if hs.config.run_background_tasks: + self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + @wrap_as_background_process("read_forward_extremities") async def _read_forward_extremities(self): def fetch(txn): txn.execute( @@ -137,3 +141,190 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) + + async def count_daily_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return await self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) + + async def count_monthly_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return await self.db_pool.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + async def count_r30_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns: + A mapping of counts globally as well as broken out by platform. + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? 
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + @wrap_as_background_process("generate_user_daily_visits") + async def generate_user_daily_visits(self) -> None: + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self._clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. 
The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + await self.db_pool.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e93aad33cd..c66f558567 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -32,6 +32,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._max_mau_value = hs.config.max_mau_value + @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -124,60 +127,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): desc="user_last_seen_monthly_active", ) - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_stats_only = hs.config.mau_stats_only - self._max_mau_value = hs.config.max_mau_value - - # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting #6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. 
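`_get_start_of_day`, now on the metrics store, converts the current time to the UTC midnight that began the day, in milliseconds; `generate_user_daily_visits` uses that value as the bucket timestamp. A small worked example of the arithmetic:

```python
import calendar
import time


def get_start_of_day_ms(now=None):
    """Millisecond unixtime for the start of the current UTC day."""
    now = time.gmtime(now)
    today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
    return today_start * 1000


# e.g. 2020-10-01 13:45:00 UTC rounds down to 2020-10-01 00:00:00 UTC
print(get_start_of_day_ms(1601559900))  # 1601510400000
```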
@@ -257,6 +206,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._mau_stats_only = hs.config.mau_stats_only + + # Do not add more reserved users than the total allowable number + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index a83df7759d..a85867936f 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import logging import re from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.types import Cursor @@ -48,6 +50,21 @@ class RegistrationWorkerStore(SQLBaseStore): database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) + self._account_validity = hs.config.account_validity + if hs.config.run_background_tasks and self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + if hs.config.run_background_tasks: + self.clock.looping_call( + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS + ) + @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( @@ -778,6 +795,78 @@ class RegistrationWorkerStore(SQLBaseStore): "delete_threepid_session", delete_threepid_session_txn ) + @wrap_as_background_process("cull_expired_threepid_validation_tokens") + async def cull_expired_threepid_validation_tokens(self) -> None: + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self.clock.time_msec(), + ) + + async def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db_pool.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + await self.db_pool.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. 
+ """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db_pool.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -911,28 +1000,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._account_validity = hs.config.account_validity self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - async def add_access_token_to_user( self, user_id: str, @@ -964,6 +1033,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, "access_tokens", {"token": token}, "device_id" + ) + + self.db_pool.simple_update_txn( + txn, "access_tokens", {"token": token}, {"device_id": device_id} + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) + + return old_device_id + + async def set_device_for_access_token(self, token: str, device_id: str) -> str: + """Sets the device ID associated with an access token. + + Args: + token: The access token to modify. + device_id: The new device ID. + Returns: + The old device ID associated with the access token. + """ + + return await self.db_pool.runInteraction( + "set_device_for_access_token", + self._set_device_for_access_token_txn, + token, + device_id, + ) + async def register_user( self, user_id: str, @@ -1447,22 +1546,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def cull_expired_threepid_validation_tokens(self) -> None: - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn(txn, ts): - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - txn.execute(sql, (ts,)) - - await self.db_pool.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), - ) - async def set_user_deactivated_status( self, user_id: str, deactivated: bool ) -> None: @@ -1492,61 +1575,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) txn.call_after(self.is_guest.invalidate, (user_id,)) - async def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. 
- """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.db_pool.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - await self.db_pool.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self.db_pool.simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3c7630857f..c0f2af0785 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore): "count_public_rooms", _count_public_rooms_txn ) + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return await self.db_pool.runInteraction("get_rooms", f) + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], @@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - async def get_room_count(self) -> int: - """Retrieve the total number of rooms. 
- """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return await self.db_pool.runInteraction("get_rooms", f) - async def add_event_report( self, room_id: str, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 86ffe2479e..20fcdaa529 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -21,12 +21,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - LoggingTransaction, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine @@ -60,15 +55,16 @@ class RoomMemberWorkerStore(EventsWorkerStore): # background update still running? self._current_state_events_membership_up_to_date = False - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, + txn = db_conn.cursor( + txn_name="_check_safe_current_state_events_membership_updated" ) self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() - if self.hs.config.metrics_flags.known_servers: + if ( + self.hs.config.run_background_tasks + and self.hs.config.metrics_flags.known_servers + ): self._known_servers_count = 1 self.hs.get_clock().looping_call( run_as_background_process, diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py index 3edfcfd783..45b846e6a7 100644 --- a/synapse/storage/databases/main/schema/delta/20/pushers.py +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py @@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs): row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" 
for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py index ee675e71ff..21f57825d4 100644 --- a/synapse/storage/databases/main/schema/delta/25/fts.py +++ b/synapse/storage/databases/main/schema/delta/25/fts.py @@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py index b7972cfa8e..1c6058063f 100644 --- a/synapse/storage/databases/main/schema/delta/27/ts.py +++ b/synapse/storage/databases/main/schema/delta/27/ts.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py index b42c02710a..7f08fabe9f 100644 --- a/synapse/storage/databases/main/schema/delta/30/as_users.py +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py @@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),), [as_id] + chunk, ) diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py index 9bb504aad5..5be81c806a 100644 --- a/synapse/storage/databases/main/schema/delta/31/pushers.py +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py @@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs): row = list(row) row[12] = token_to_stream_ordering(row[12]) cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" 
for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py index 63b757ade6..b84c844e3a 100644 --- a/synapse/storage/databases/main/schema/delta/31/search_update.py +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py @@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search_order", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py index a3e81eeac7..e928c66a8f 100644 --- a/synapse/storage/databases/main/schema/delta/33/event_fields.py +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_fields_sender_url", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py index a26057dfb6..ad875c733a 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py @@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" - ), - (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py index 1de8b54961..bb7296852a 100644 --- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py @@ -1,6 +1,8 @@ import logging +from io import StringIO from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream logger = logging.getLogger(__name__) @@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs): select_clause, ) - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) + execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py index 63b5acdcf7..44917f0a2e 100644 --- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py @@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INNER JOIN room_memberships AS r USING (event_id) WHERE type = 'm.room.member' AND state_key LIKE ? 
""" - sql = database_engine.convert_param_style(sql) cur.execute(sql, ("%:" + config.server_name,)) cur.execute( diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql new file mode 100644 index 0000000000..7851a0a825 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql @@ -0,0 +1,20 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL -- JSON-encoded client-defined data +); diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql new file mode 100644 index 0000000000..4ed981dbf8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres new file mode 100644 index 0000000000..841186b826 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A unique and immutable mapping between instance name and an integer ID. This +-- lets us refer to instances via a small ID in e.g. stream tokens, without +-- having to encode the full name. +CREATE TABLE IF NOT EXISTS instance_map ( + instance_id SERIAL PRIMARY KEY, + instance_name TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name); diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 37249f1e3f..e3b9ff5ca6 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -53,7 +53,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import Collection, PersistedEventPosition, RoomStreamToken +from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -208,6 +210,55 @@ def _make_generic_sql_bound( ) +def _filter_results( + lower_token: Optional[RoomStreamToken], + upper_token: Optional[RoomStreamToken], + instance_name: str, + topological_ordering: int, + stream_ordering: int, +) -> bool: + """Returns True if the event persisted by the given instance at the given + topological/stream_ordering falls between the two tokens (taking a None + token to mean unbounded). + + Used to filter results from fetching events in the DB against the given + tokens. This is necessary to handle the case where the tokens include + position maps, which we handle by fetching more than necessary from the DB + and then filtering (rather than attempting to construct a complicated SQL + query). + """ + + event_historical_tuple = ( + topological_ordering, + stream_ordering, + ) + + if lower_token: + if lower_token.topological is not None: + # If these are historical tokens we compare the `(topological, stream)` + # tuples. + if event_historical_tuple <= lower_token.as_historical_tuple(): + return False + + else: + # If these are live tokens we compare the stream ordering against the + # writers stream position. + if stream_ordering <= lower_token.get_stream_pos_for_instance( + instance_name + ): + return False + + if upper_token: + if upper_token.topological is not None: + if upper_token.as_historical_tuple() < event_historical_tuple: + return False + else: + if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering: + return False + + return True + + def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create @@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): raise NotImplementedError() def get_room_max_token(self) -> RoomStreamToken: - return RoomStreamToken(None, self.get_room_max_stream_ordering()) + """Get a `RoomStreamToken` that marks the current maximum persisted + position of the events stream. Useful to get a token that represents + "now". + + The token returned is a "live" token that may have an instance_map + component. 
+ """ + + min_pos = self._stream_id_gen.get_current_token() + + positions = {} + if isinstance(self._stream_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._stream_id_gen.get_positions().items() + if p > min_pos + } + + return RoomStreamToken(None, min_pos, positions) async def get_room_events_stream_for_rooms( self, @@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if from_key == to_key: return [], from_key - from_id = from_key.stream - to_id = to_key.stream - - has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) + has_changed = self._events_stream_cache.has_entity_changed( + room_id, from_key.stream + ) if not has_changed: return [], from_key def f(txn): - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT event_id, instance_name, topological_ordering, stream_ordering + FROM events + WHERE + room_id = ? + AND not outlier + AND stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering %s LIMIT ? + """ % ( + order, + ) + txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ][:limit] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=False) if order.lower() == "desc": ret.reverse() @@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" 
- " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT m.event_id, instance_name, topological_ordering, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? + ORDER BY e.stream_ordering ASC + """ + txn.execute(sql, (user_id, min_from_id, max_to_id,)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ] return rows @@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int - ) -> Tuple[int, int, str]: + ) -> Optional[Tuple[int, int, str]]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): ) return "t%d-%d" % (topo, token) - async def get_stream_id_for_event(self, event_id: str) -> int: - """The stream ID for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream ID. - """ - return await self.db_pool.runInteraction( - "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, - ) - def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none=False, ) -> int: @@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: order = "ASC" + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. 
+ if from_token.topological is not None: + from_bound = ( + from_token.as_historical_tuple() + ) # type: Tuple[Optional[int], int] + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound = None # type: Optional[Tuple[Optional[int], int]] + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds += " AND " + filter_clause args.extend(filter_args) - args.append(int(limit)) + # We fetch more events as we'll filter the result set + args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" @@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): select_keywords += "DISTINCT" sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering + %(select_keywords)s + event_id, instance_name, + topological_ordering, stream_ordering FROM events %(join_clause)s WHERE outlier = ? AND room_id = ? AND %(bounds)s @@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): txn.execute(sql, args) - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + # Filter the result set. + rows = [ + _EventDictReturn(event_id, topological_ordering, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + lower_token=to_token if direction == "b" else from_token, + upper_token=from_token if direction == "b" else to_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, + ) + ][:limit] if rows: topo = rows[-1].topological_ordering @@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): return (events, token) + @cached() + async def get_id_for_instance(self, instance_name: str) -> int: + """Get a unique, immutable ID that corresponds to the given Synapse worker instance. + """ + + def _get_id_for_instance_txn(txn): + instance_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + allow_none=True, + ) + if instance_id is not None: + return instance_id + + # If we don't have an entry upsert one. + # + # We could do this before the first check, and rely on the cache for + # efficiency, but each UPSERT causes the next ID to increment which + # can quickly bloat the size of the generated IDs for new instances. 
+ self.db_pool.simple_upsert_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + values={}, + ) + + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + ) + + return await self.db_pool.runInteraction( + "get_id_for_instance", _get_id_for_instance_txn + ) + + @cached() + async def get_name_from_instance_id(self, instance_id: int) -> str: + """Get the instance name from an ID previously returned by + `get_id_for_instance`. + """ + + return await self.db_pool.simple_select_one_onecol( + table="instance_map", + keyvalues={"instance_id": instance_id}, + retcol="instance_name", + desc="get_name_from_instance_id", + ) + class StreamStore(StreamWorkerStore): def get_room_max_stream_ordering(self) -> int: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 97aed1500e..7d46090267 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple( SENTINEL = object() -class TransactionStore(SQLBaseStore): +class TransactionWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + + @wrap_as_background_process("cleanup_transactions") + async def _cleanup_transactions(self) -> None: + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + await self.db_pool.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) + + +class TransactionStore(TransactionWorkerStore): """A collection of queries for handling PDUs. 
""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) - self._destination_retry_cache = ExpiringCache( cache_name="get_destination_retry_timings", clock=self._clock, @@ -266,22 +284,6 @@ class TransactionStore(SQLBaseStore): }, ) - def _start_cleanup_transactions(self): - return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions - ) - - async def _cleanup_transactions(self) -> None: - now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 - - def _cleanup_transactions_txn(txn): - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - - await self.db_pool.runInteraction( - "_cleanup_transactions", _cleanup_transactions_txn - ) - async def store_destination_rooms_entries( self, destinations: Iterable[str], room_id: str, stream_ordering: int, ) -> None: diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 3b9211a6d2..79b7ece330 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore): ) return [(row["user_agent"], row["ip"]) for row in rows] - -class UIAuthStore(UIAuthWorkerStore): async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore): iterable=session_ids, keyvalues={}, ) + + +class UIAuthStore(UIAuthWorkerStore): + pass diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 72939f3984..4d2d88d1f0 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -248,6 +248,8 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) event_stream_id = event.internal_metadata.stream_ordering + # stream ordering should have been assigned by now + assert event_stream_id pos = PersistedEventPosition(self._instance_name, event_stream_id) return pos, self.main_store.get_room_max_token() diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 4957e77f4c..459754feab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import imp import logging import os @@ -24,9 +23,10 @@ from typing import Optional, TextIO import attr from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine -from synapse.storage.types import Connection, Cursor +from synapse.storage.types import Cursor from synapse.types import Collection logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = ( def prepare_database( - db_conn: Connection, + db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], databases: Collection[str] = ["main", "state"], @@ -89,7 +89,7 @@ def prepare_database( """ try: - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="prepare_database") # sqlite does not automatically start transactions for DDL / SELECT statements, # so we start one before running anything. This ensures that any upgrades @@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases): executescript(cur, entry.absolute_path) cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (max_current_ver, False), ) @@ -486,17 +484,13 @@ def _upgrade_existing_database( # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)" - ), + "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)", (v, relative_path), ) cur.execute("DELETE FROM schema_version") cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (v, True), ) @@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) schemas to be applied """ cur.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_module_schemas WHERE module_name = ?" - ), - (modname,), + "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), ) applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: @@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)" - ), + "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)", (modname, name), ) @@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine): if current_version: txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), + "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), ) applied_deltas = [d for d, in txn] diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 2d2b560e74..970bb1b9da 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -61,3 +61,9 @@ class Connection(Protocol): def rollback(self, *args, **kwargs) -> None: ... + + def __enter__(self) -> "Connection": + ... + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + ... 
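[Illustrative note, not part of the patch] The hunks above consistently drop database_engine.convert_param_style() and open cursors through the logging connection wrapper with an explicit txn_name, so "?"-parameterised SQL is handed straight to execute(). A minimal sketch of the resulting calling pattern, using a hypothetical example_table and placeholder names that are not part of this changeset:

    def run_create(cur, database_engine, *args, **kwargs):
        # "?" placeholders go directly to execute(); no convert_param_style() call.
        cur.execute(
            "INSERT INTO example_table (name, value) VALUES (?, ?)",
            ("example_row", "{}"),
        )

    # Startup-style code asks the wrapped connection for a named cursor:
    cur = db_conn.cursor(txn_name="count_example_rows")
    cur.execute("SELECT count(*) FROM example_table WHERE name = ?", ("example_row",))
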
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index ad017207aa..d7e40aaa8b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -55,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1): """ # debug logging for https://github.com/matrix-org/synapse/issues/7968 logger.info("initialising stream generator for %s(%s)", table, column) - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: @@ -270,7 +270,7 @@ class MultiWriterIdGenerator: def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ): - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_ids") # Load the current positions of all writers for the stream. if self._writers: @@ -284,15 +284,12 @@ class MultiWriterIdGenerator: stream_name = ? AND instance_name != ALL(?) """ - sql = self._db.engine.convert_param_style(sql) cur.execute(sql, (self._stream_name, self._writers)) sql = """ SELECT instance_name, stream_id FROM stream_positions WHERE stream_name = ? """ - sql = self._db.engine.convert_param_style(sql) - cur.execute(sql, (self._stream_name,)) self._current_positions = { @@ -341,7 +338,6 @@ class MultiWriterIdGenerator: "instance": instance_column, "cmp": "<=" if self._positive else ">=", } - sql = self._db.engine.convert_param_style(sql) cur.execute(sql, (min_stream_id * self._return_factor,)) self._persisted_upto_position = min_stream_id @@ -422,7 +418,7 @@ class MultiWriterIdGenerator: self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) - new_cur = None + new_cur = None # type: Optional[int] if self._unfinished_ids: # If there are unfinished IDs then the new position will be the diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 2dd95e2709..ff2d038ad2 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -17,6 +17,7 @@ import logging import threading from typing import Callable, List, Optional +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import ( BaseDatabaseEngine, IncorrectDatabaseSetup, @@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta): @abc.abstractmethod def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, ): """Should be called during start up to test that the current value of the sequence is greater than or equal to the maximum ID in the table. @@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator): return [i for (i,) in txn] def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, ): - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="sequence.check_consistency") # First we get the current max ID from the table. table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { |