diff options
Diffstat (limited to 'synapse/storage')
64 files changed, 1980 insertions, 866 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 79ec8f119d..0ba3a025cf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from typing import ( overload, ) +import attr from prometheus_client import Histogram from typing_extensions import Literal @@ -90,13 +91,17 @@ def make_pool( return adbapi.ConnectionPool( db_config.config["name"], cp_reactor=reactor, - cp_openfun=engine.on_new_connection, + cp_openfun=lambda conn: engine.on_new_connection( + LoggingDatabaseConnection(conn, engine, "on_new_connection") + ), **db_config.config.get("args", {}) ) def make_conn( - db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, + default_txn_name: str, ) -> Connection: """Make a new connection to the database and return it. @@ -109,11 +114,60 @@ def make_conn( for k, v in db_config.config.get("args", {}).items() if not k.startswith("cp_") } - db_conn = engine.module.connect(**db_params) + native_db_conn = engine.module.connect(**db_params) + db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) + engine.on_new_connection(db_conn) return db_conn +@attr.s(slots=True) +class LoggingDatabaseConnection: + """A wrapper around a database connection that returns `LoggingTransaction` + as its cursor class. + + This is mainly used on startup to ensure that queries get logged correctly + """ + + conn = attr.ib(type=Connection) + engine = attr.ib(type=BaseDatabaseEngine) + default_txn_name = attr.ib(type=str) + + def cursor( + self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + ) -> "LoggingTransaction": + if not txn_name: + txn_name = self.default_txn_name + + return LoggingTransaction( + self.conn.cursor(), + name=txn_name, + database_engine=self.engine, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, + ) + + def close(self) -> None: + self.conn.close() + + def commit(self) -> None: + self.conn.commit() + + def rollback(self, *args, **kwargs) -> None: + self.conn.rollback(*args, **kwargs) + + def __enter__(self) -> "Connection": + self.conn.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + return self.conn.__exit__(exc_type, exc_value, traceback) + + # Proxy through any unknown lookups to the DB conn class. + def __getattr__(self, name): + return getattr(self.conn, name) + + # The type of entry which goes on our after_callbacks and exception_callbacks lists. # # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so @@ -247,6 +301,12 @@ class LoggingTransaction: def close(self) -> None: self.txn.close() + def __enter__(self) -> "LoggingTransaction": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + class PerformanceCounters: def __init__(self): @@ -395,7 +455,7 @@ class DatabasePool: def new_transaction( self, - conn: Connection, + conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], @@ -403,6 +463,24 @@ class DatabasePool: *args: Any, **kwargs: Any ) -> R: + """Start a new database transaction with the given connection. + + Note: The given func may be called multiple times under certain + failure modes. This is normally fine when in a standard transaction, + but care must be taken if the connection is in `autocommit` mode that + the function will correctly handle being aborted and retried half way + through its execution. + + Args: + conn + desc + after_callbacks + exception_callbacks + func + *args + **kwargs + """ + start = monotonic_time() txn_id = self._TXN_ID @@ -418,12 +496,10 @@ class DatabasePool: i = 0 N = 5 while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.engine, - after_callbacks, - exception_callbacks, + cursor = conn.cursor( + txn_name=name, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, ) try: r = func(cursor, *args, **kwargs) @@ -508,7 +584,12 @@ class DatabasePool: sql_txn_timer.labels(desc).observe(duration) async def runInteraction( - self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + desc: str, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Starts a transaction on the database and runs a given function @@ -518,6 +599,18 @@ class DatabasePool: database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transactions + that are only a single query. + + Currently, this is only implemented for Postgres. SQLite will still + run the function inside a transaction. + + WARNING: This means that if func fails half way through then + the changes will *not* be rolled back. `func` may also get + called multiple times if the transaction is retried, so must + correctly handle that case. + args: positional args to pass to `func` kwargs: named args to pass to `func` @@ -538,6 +631,7 @@ class DatabasePool: exception_callbacks, func, *args, + db_autocommit=db_autocommit, **kwargs ) @@ -551,7 +645,11 @@ class DatabasePool: return cast(R, result) async def runWithConnection( - self, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. @@ -560,6 +658,9 @@ class DatabasePool: database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. args: positional args to pass to `func` + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transaction + that are only a single query. Currently only affects postgres. kwargs: named args to pass to `func` Returns: @@ -575,6 +676,13 @@ class DatabasePool: start_time = monotonic_time() def inner_func(conn, *args, **kwargs): + # We shouldn't be in a transaction. If we are then something + # somewhere hasn't committed after doing work. (This is likely only + # possible during startup, as `run*` will ensure changes are + # committed/rolled back before putting the connection back in the + # pool). + assert not self.engine.in_transaction(conn) + with LoggingContext("runWithConnection", parent_context) as context: sched_duration_sec = monotonic_time() - start_time sql_scheduling_timer.observe(sched_duration_sec) @@ -584,7 +692,17 @@ class DatabasePool: logger.debug("Reconnecting closed database connection") conn.reconnect() - return func(conn, *args, **kwargs) + try: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, True) + + db_conn = LoggingDatabaseConnection( + conn, self.engine, "runWithConnection" + ) + return func(db_conn, *args, **kwargs) + finally: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, False) return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) @@ -1621,7 +1739,7 @@ class DatabasePool: def get_cache_dict( self, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, table: str, entity_column: str, stream_column: str, @@ -1642,9 +1760,7 @@ class DatabasePool: "limit": limit, } - sql = self.engine.convert_param_style(sql) - - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="get_cache_dict") txn.execute(sql, (int(max_value),)) cache = {row[0]: int(row[1]) for row in txn} diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index aa5d490624..0c24325011 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -46,7 +46,7 @@ class Databases: db_name = database_config.name engine = create_engine(database_config.config) - with make_conn(database_config, engine) as db_conn: + with make_conn(database_config, engine, "startup") as db_conn: logger.info("[database config %r]: Checking database server", db_name) engine.check_database(db_conn) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 2ae2fbd5d7..9b16f45f3e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar import logging -import time from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState @@ -160,19 +158,25 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, - instance_name="master", + stream_name="caches", + instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) else: self._cache_id_gen = None - super(DataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) @@ -262,9 +266,6 @@ class DataStore( self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -283,7 +284,6 @@ class DataStore( " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" ) - sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) @@ -295,192 +295,6 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - async def count_daily_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return await self.db_pool.runInteraction( - "count_daily_users", self._count_users, yesterday - ) - - async def count_monthly_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return await self.db_pool.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() - return count - - async def count_r30_users(self) -> Dict[str, int]: - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns: - A mapping of counts globally as well as broken out by platform. - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = txn.fetchone() - results["all"] = count - - return results - - return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - async def generate_user_daily_visits(self) -> None: - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - await self.db_pool.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 4436b1a83d..49ee23470d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,6 +18,7 @@ import abc import logging from typing import Dict, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator @@ -29,22 +30,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. +class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -293,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore): self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext ) -> bool: ignored_account_data = await self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", + AccountDataTypes.IGNORED_USER_LIST, ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False - return ignored_user_id in ignored_account_data.get("ignored_users", {}) + try: + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + except TypeError: + # The type of the ignored_users field is invalid. + return False class AccountDataStore(AccountDataWorkerStore): @@ -315,7 +318,7 @@ class AccountDataStore(AccountDataWorkerStore): ], ) - super(AccountDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream @@ -341,7 +344,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -389,7 +392,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 454c0bc50c..85f6b1e3fd 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index f211ddbaf8..4bb2b9c28c 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.databases.main.events import encode_json from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.frozenutils import frozendict_json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict(event.room_version, event.get_dict()) ) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c2fc847fbc..a25a888443 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "user_ips_device_index", @@ -351,16 +351,70 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return updated -class ClientIpStore(ClientIpBackgroundUpdateStore): +class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + if hs.config.run_background_tasks and self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.db_pool.updates.has_completed_background_update( + "devices_last_seen" + ): + # Only start pruning if we have finished populating the devices + # last seen info. + return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? + ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ + + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn + ) + + +class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) - super(ClientIpStore, self).__init__(database, db_conn, hs) - - self.user_ips_max_age = hs.config.user_ips_max_age + super().__init__(database, db_conn, hs) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update = {} @@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): @@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } for (access_token, ip), (user_agent, last_seen) in results.items() ] - - @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ - - if self.user_ips_max_age is None: - # Nothing to do - return - - if not await self.db_pool.updates.has_completed_background_update( - "devices_last_seen" - ): - # Only start pruning if we have finished populating the devices - # last seen info. - return - - # We do a slightly funky SQL delete to ensure we don't try and delete - # too much at once (as the table may be very large from before we - # started pruning). - # - # This works by finding the max last_seen that is less than the given - # time, but has no more than N rows before it, deleting all rows with - # a lesser last_seen time. (We COALESCE so that the sub-SELECT always - # returns exactly one row). - sql = """ - DELETE FROM user_ips - WHERE last_seen <= ( - SELECT COALESCE(MAX(last_seen), -1) - FROM ( - SELECT last_seen FROM user_ips - WHERE last_seen <= ? - ORDER BY last_seen ASC - LIMIT 5000 - ) AS u - ) - """ - - timestamp = self.clock.time_msec() - self.user_ips_max_age - - def _prune_old_user_ips_txn(txn): - txn.execute(sql, (timestamp,)) - - await self.db_pool.runInteraction( - "_prune_old_user_ips", _prune_old_user_ips_txn - ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 0044433110..d42faa3f1f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_inbox_stream_index", @@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 306fc6947c..2d0a6408b5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key -from synapse.util import json_encoder +from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with await self._device_list_id_gen.get_next() as stream_id: + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -698,10 +698,84 @@ class DeviceWorkerStore(SQLBaseStore): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + # FIXME: make sure device ID still exists in devices table + row = await self.db_pool.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return ( + (row["device_id"], json_decoder.decode(row["device_data"])) if row else None + ) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + self.db_pool.simple_upsert_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + values={"device_id": device_id, "device_data": device_data}, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: JsonDict + ) -> Optional[str]: + """Store a dehydrated device for a user. + + Args: + user_id: the user that we are storing the device for + device_id: the ID of the dehydrated device + device_data: the dehydrated device information + Returns: + device id of the user's previous dehydrated device, if any + """ + return await self.db_pool.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, + device_id, + json_encoder.encode(device_data), + ) + + async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: + """Remove a dehydrated device. + + Args: + user_id: the user that the dehydrated device belongs to + device_id: the ID of the dehydrated device + """ + count = await self.db_pool.simple_delete( + "dehydrated_devices", + {"user_id": user_id, "device_id": device_id}, + desc="remove_dehydrated_device", + ) + return count >= 1 + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_lists_stream_idx", @@ -826,7 +900,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. @@ -837,7 +911,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: str + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -955,7 +1029,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def update_remote_device_list_cache_entry( - self, user_id: str, device_id: str, content: JsonDict, stream_id: int + self, user_id: str, device_id: str, content: JsonDict, stream_id: str ) -> None: """Updates a single device in the cache of a remote user's devicelist. @@ -983,7 +1057,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, content: JsonDict, - stream_id: int, + stream_id: str, ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( @@ -1093,7 +1167,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1108,7 +1182,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c8df0bcb3f..359dc6e968 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def set_e2e_fallback_keys( + self, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ + # fallback_keys will usually only have one item in it, so using a for + # loop (as opposed to calling simple_upsert_many_txn) won't be too bad + # FIXME: make sure that only one key per algorithm is uploaded + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + await self.db_pool.simple_upsert( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + desc="set_e2e_fallback_key", + ) + + @cached(max_entries=10000) + async def get_e2e_unused_fallback_key_types( + self, user_id: str, device_id: str + ) -> List[str]: + """Returns the fallback key types that have an unused key. + + Args: + user_id: the user whose keys are being queried + device_id: the device whose keys are being queried + + Returns: + a list of key types + """ + return await self.db_pool.simple_select_onecol( + "e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id, "used": False}, + retcol="algorithm", + desc="get_e2e_unused_fallback_key_types", + ) + async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Optional[dict]: @@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): " WHERE user_id = ? AND device_id = ? AND algorithm = ?" " LIMIT 1" ) + fallback_sql = ( + "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) result = {} delete = [] + used_fallbacks = [] for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: + otk_row = txn.fetchone() + if otk_row is not None: + key_id, key_json = otk_row device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) + else: + # no one-time key available, so see if there's a fallback + # key + txn.execute(fallback_sql, (user_id, device_id, algorithm)) + fallback_row = txn.fetchone() + if fallback_row is not None: + key_id, key_json, used = fallback_row + device_result[algorithm + ":" + key_id] = key_json + if not used: + used_fallbacks.append( + (user_id, device_id, algorithm, key_id) + ) + + # drop any one-time keys that were claimed sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + # mark fallback keys as used + for user_id, device_id, algorithm, key_id in used_fallbacks: + self.db_pool.simple_update_txn( + txn, + "e2e_fallback_keys_json", + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + {"used": True}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) + return result return await self.db_pool.runInteraction( @@ -754,6 +844,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn @@ -831,7 +934,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key (dict): the key data """ - with await self._cross_signing_id_gen.get_next() as stream_id: + async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4c3c162acf..6d3689c09e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventFederationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7805fb814e..80f3b4d740 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union import attr from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -68,17 +68,13 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None self.stream_ordering_day_ago = None - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) + cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) cur.close() @@ -661,7 +657,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9a80f419e3..b4abd961b9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,7 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -52,16 +52,6 @@ event_counter = Counter( ) -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -156,15 +146,15 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = await self._backfill_id_gen.get_next_mult( + stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = await self._stream_id_gen.get_next_mult( + stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream @@ -341,6 +331,10 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # stream orderings should have been assigned by now + assert min_stream_order + assert max_stream_order + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -743,7 +737,9 @@ class PersistEventsStore: logger.exception("") raise - metadata_json = encode_json(event.internal_metadata.get_dict()) + metadata_json = frozendict_json_encoder.encode( + event.internal_metadata.get_dict() + ) sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -797,10 +793,10 @@ class PersistEventsStore: { "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": encode_json( + "internal_metadata": frozendict_json_encoder.encode( event.internal_metadata.get_dict() ), - "json": encode_json(event_dict(event)), + "json": frozendict_json_encoder.encode(event_dict(event)), "format_version": event.format_version, } for event, _ in events_and_contexts @@ -1108,6 +1104,10 @@ class PersistEventsStore: def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ + + def str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) else None + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", @@ -1118,8 +1118,8 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), + "display_name": str_or_none(event.content.get("displayname")), + "avatar_url": str_or_none(event.content.get("avatar_url")), } for event in events ], diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index e53c6373a8..5e4af2eb51 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 17f5997b89..b7ed8ca6ab 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import division - import itertools import logging import threading @@ -76,8 +74,15 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): + # Whether to use dedicated DB threads for event fetching. This is only used + # if there are multiple DB threads available. When used will lock the DB + # thread for periods of time (so unit tests want to disable this when they + # run DB transactions on the main thread). See EVENT_QUEUE_* for more + # options controlling this. + USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True + def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` @@ -85,21 +90,25 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="events", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, ) self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="backfill", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_backfill_stream_seq", positive=False, + writers=hs.config.worker.writers.events, ) else: # We shouldn't be running in worker mode with SQLite, but its useful @@ -520,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore): if not event_list: single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): self._event_fetch_ongoing -= 1 return else: @@ -710,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) + original_ev.internal_metadata.stream_ordering = row["stream_ordering"] event_map[event_id] = original_ev @@ -777,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore): * event_id (str) + * stream_ordering (int): stream ordering for this event + * json (str): json-encoded event structure * internal_metadata (str): json-encoded internal metadata dict @@ -809,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore): sql = """\ SELECT e.event_id, - e.internal_metadata, - e.json, - e.format_version, + e.stream_ordering, + ej.internal_metadata, + ej.json, + ej.format_version, r.room_version, rej.reason - FROM event_json as e - LEFT JOIN rooms r USING (room_id) + FROM events AS e + JOIN event_json AS ej USING (event_id) + LEFT JOIN rooms r ON r.room_id = e.room_id LEFT JOIN rejections as rej USING (event_id) WHERE """ @@ -829,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore): event_id = row[0] event_dict[event_id] = { "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "room_version_id": row[4], - "rejected_reason": row[5], + "stream_ordering": row[1], + "internal_metadata": row[2], + "json": row[3], + "format_version": row[4], + "room_version_id": row[5], + "rejected_reason": row[6], "redactions": [], } diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index ccfbb2135e..7218191965 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with await self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 1d76c761a6..cc538c5c10 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -24,9 +24,7 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__( - database, db_conn, hs - ) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -94,7 +92,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" def __init__(self, database: DatabasePool, db_conn, hs): - super(MediaRepositoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 686052bd83..0acf0617ca 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,17 +12,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import typing -from collections import Counter +import calendar +import logging +import time +from typing import Dict -from synapse.metrics import BucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics import GaugeBucketCollector +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +logger = logging.getLogger(__name__) + +# Collect metrics on the number of forward extremities that exist. +_extremities_collecter = GaugeBucketCollector( + "synapse_forward_extremities", + "Number of rooms on the server with the given number of forward extremities" + " or fewer", + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500], +) + +# we also expose metrics on the "number of excess extremity events", which is +# (E-1)*N, where E is the number of extremities and N is the number of state +# events in the room. This is an approximation to the number of state events +# we could remove from state resolution by reducing the graph to a single +# forward extremity. +_excess_state_events_collecter = GaugeBucketCollector( + "synapse_excess_extremity_events", + "Number of rooms on the server with the given number of excess extremity " + "events, or fewer", + buckets=[0] + [1 << n for n in range(12)], +) + class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """Functions to pull various metrics from the DB, for e.g. phone home @@ -32,40 +56,37 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = ( - Counter() - ) # type: typing.Counter[int] - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) + if hs.config.run_background_tasks: + self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + @wrap_as_background_process("read_forward_extremities") async def _read_forward_extremities(self): def fetch(txn): txn.execute( """ - select count(*) c from event_forward_extremities - group by room_id + SELECT t1.c, t2.c + FROM ( + SELECT room_id, COUNT(*) c FROM event_forward_extremities + GROUP BY room_id + ) t1 LEFT JOIN ( + SELECT room_id, COUNT(*) c FROM current_state_events + GROUP BY room_id + ) t2 ON t1.room_id = t2.room_id """ ) return txn.fetchall() res = await self.db_pool.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = Counter([x[0] for x in res]) + + _extremities_collecter.update_data(x[0] for x in res) + + _excess_state_events_collecter.update_data( + (x[0] - 1) * x[1] for x in res if x[1] + ) async def count_daily_messages(self): """ @@ -120,3 +141,190 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) + + async def count_daily_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return await self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) + + async def count_monthly_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return await self.db_pool.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + async def count_r30_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns: + A mapping of counts globally as well as broken out by platform. + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + @wrap_as_background_process("generate_user_daily_visits") + async def generate_user_daily_visits(self) -> None: + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self._clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + await self.db_pool.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 1d793d3deb..c66f558567 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -28,10 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._max_mau_value = hs.config.max_mau_value + @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -41,7 +44,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + # Exclude app service users + sql = """ + SELECT COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users + ON monthly_active_users.user_id=users.name + WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); + """ txn.execute(sql) (count,) = txn.fetchone() return count @@ -117,60 +127,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): desc="user_last_seen_monthly_active", ) - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) - - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_stats_only = hs.config.mau_stats_only - self._max_mau_value = hs.config.max_mau_value - - # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting #6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. @@ -250,6 +206,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._mau_stats_only = hs.config.mau_stats_only + + # Do not add more reserved users than the total allowable number + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c9f655dfb7..dbbb99cb95 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = await self._presence_id_gen.get_next_mult( + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index d7a03cbf7d..ecfc6717b3 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): The set of state groups that are referenced by deleted events. """ + parsed_token = await RoomStreamToken.parse(self, token) + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, - token, + parsed_token, delete_local_events, ) - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - + def _purge_history_txn(self, txn, room_id, token, delete_local_events): # Tables that should be pruned: # event_auth # event_backward_extremities diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 9790a31998..711d5aa23d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -61,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False): return rules +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, ReceiptsWorkerStore, @@ -68,17 +70,14 @@ class PushRulesWorkerStore( RoomMemberWorkerStore, EventsWorkerStore, SQLBaseStore, + metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: self._push_rules_stream_id_gen = StreamIdGenerator( @@ -339,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -586,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -617,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore): Raises: NotFoundError if the rule does not exist. """ - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -755,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c388468273..df8609b97b 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 4a0d5a320e..c79ddff680 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -31,17 +31,15 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -class ReceiptsWorkerStore(SQLBaseStore): +# The ABCMeta metaclass ensures that it cannot be instantiated without +# the abstract methods being implemented. +class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_max_receipt_stream_id` which can be called in the initializer. """ - # This ABCMeta metaclass ensures that we cannot be instantiated without - # the abstract methods being implemented. - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -388,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore): db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -526,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - with await self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 01f20c03c2..a85867936f 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import re from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.types import Cursor @@ -36,15 +38,33 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() + # Note: we don't check this sequence for consistency as we'd have to + # call `find_max_generated_user_id_localpart` each time, which is + # expensive if there are many entries. self._user_id_seq = build_sequence_generator( database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) + self._account_validity = hs.config.account_validity + if hs.config.run_background_tasks and self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + if hs.config.run_background_tasks: + self.clock.looping_call( + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS + ) + @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( @@ -116,6 +136,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_expiration_ts_for_user", ) + async def is_account_expired(self, user_id: str, current_ts: int) -> bool: + """ + Returns whether an user account is expired. + + Args: + user_id: The user's ID + current_ts: The current timestamp + + Returns: + Whether the user account has expired + """ + expiration_ts = await self.get_expiration_ts_for_user(user_id) + return expiration_ts is not None and current_ts >= expiration_ts + async def set_account_validity_for_user( self, user_id: str, @@ -379,7 +413,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -387,7 +421,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", @@ -761,10 +795,82 @@ class RegistrationWorkerStore(SQLBaseStore): "delete_threepid_session", delete_threepid_session_txn ) + @wrap_as_background_process("cull_expired_threepid_validation_tokens") + async def cull_expired_threepid_validation_tokens(self) -> None: + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self.clock.time_msec(), + ) + + async def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db_pool.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + await self.db_pool.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db_pool.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -892,30 +998,10 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) - self._account_validity = hs.config.account_validity self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - async def add_access_token_to_user( self, user_id: str, @@ -947,6 +1033,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, "access_tokens", {"token": token}, "device_id" + ) + + self.db_pool.simple_update_txn( + txn, "access_tokens", {"token": token}, {"device_id": device_id} + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) + + return old_device_id + + async def set_device_for_access_token(self, token: str, device_id: str) -> str: + """Sets the device ID associated with an access token. + + Args: + token: The access token to modify. + device_id: The new device ID. + Returns: + The old device ID associated with the access token. + """ + + return await self.db_pool.runInteraction( + "set_device_for_access_token", + self._set_device_for_access_token_txn, + token, + device_id, + ) + async def register_user( self, user_id: str, @@ -1430,22 +1546,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def cull_expired_threepid_validation_tokens(self) -> None: - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn(txn, ts): - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - txn.execute(sql, (ts,)) - - await self.db_pool.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), - ) - async def set_user_deactivated_status( self, user_id: str, deactivated: bool ) -> None: @@ -1475,61 +1575,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) txn.call_after(self.is_guest.invalidate, (user_id,)) - async def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.db_pool.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - await self.db_pool.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self.db_pool.simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 127588ce4c..c0f2af0785 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -69,7 +69,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore): "count_public_rooms", _count_public_rooms_txn ) + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return await self.db_pool.runInteraction("get_rooms", f) + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], @@ -863,7 +875,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -1074,7 +1086,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -1137,7 +1149,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1204,7 +1216,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1284,7 +1296,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, @@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - async def get_room_count(self) -> int: - """Retrieve the total number of rooms. - """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return await self.db_pool.runInteraction("get_rooms", f) - async def add_event_report( self, room_id: str, @@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: str = "b", + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + event_reports: json list of event reports + count: total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn(txn): + filters = [] + args = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.reason, + er.content, + events.sender, + room_aliases.room_alias, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN room_aliases + ON room_aliases.room_id = er.room_id + JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? + """.format( + where_clause=where_clause, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + event_reports = self.db_pool.cursor_to_dict(txn) + + if count > 0: + for row in event_reports: + try: + row["content"] = db_to_json(row["content"]) + row["event_json"] = db_to_json(row["event_json"]) + except Exception: + continue + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 91a8b43da3..20fcdaa529 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set @@ -22,12 +21,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - LoggingTransaction, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine @@ -37,7 +31,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -55,21 +49,22 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? self._current_state_events_membership_up_to_date = False - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, + txn = db_conn.cursor( + txn_name="_check_safe_current_state_events_membership_updated" ) self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() - if self.hs.config.metrics_flags.known_servers: + if ( + self.hs.config.run_background_tasks + and self.hs.config.metrics_flags.known_servers + ): self._known_servers_count = 1 self.hs.get_clock().looping_call( run_as_background_process, @@ -387,7 +382,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # for rooms the server is participating in. if self._current_state_events_membership_up_to_date: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -397,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ else: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN room_memberships AS m USING (room_id, event_id) INNER JOIN events AS e USING (room_id, event_id) @@ -408,7 +403,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + return frozenset( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + for room_id, instance, stream_id in txn + ) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] @@ -819,7 +819,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -973,7 +973,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py index 3edfcfd783..45b846e6a7 100644 --- a/synapse/storage/databases/main/schema/delta/20/pushers.py +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py @@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs): row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py index ee675e71ff..21f57825d4 100644 --- a/synapse/storage/databases/main/schema/delta/25/fts.py +++ b/synapse/storage/databases/main/schema/delta/25/fts.py @@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py index b7972cfa8e..1c6058063f 100644 --- a/synapse/storage/databases/main/schema/delta/27/ts.py +++ b/synapse/storage/databases/main/schema/delta/27/ts.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py index b42c02710a..7f08fabe9f 100644 --- a/synapse/storage/databases/main/schema/delta/30/as_users.py +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py @@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),), [as_id] + chunk, ) diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py index 9bb504aad5..5be81c806a 100644 --- a/synapse/storage/databases/main/schema/delta/31/pushers.py +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py @@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs): row = list(row) row[12] = token_to_stream_ordering(row[12]) cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py index 63b757ade6..b84c844e3a 100644 --- a/synapse/storage/databases/main/schema/delta/31/search_update.py +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py @@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search_order", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py index a3e81eeac7..e928c66a8f 100644 --- a/synapse/storage/databases/main/schema/delta/33/event_fields.py +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_fields_sender_url", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py index a26057dfb6..ad875c733a 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py @@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" - ), - (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql index 5e29c1da19..ccf287971c 100644 --- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql @@ -13,7 +13,7 @@ * limitations under the License. */ --- room_id and topoligical_ordering are denormalised from the events table in order to +-- room_id and topological_ordering are denormalised from the events table in order to -- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py index 1de8b54961..bb7296852a 100644 --- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py @@ -1,6 +1,8 @@ import logging +from io import StringIO from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream logger = logging.getLogger(__name__) @@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs): select_clause, ) - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) + execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py index 63b5acdcf7..44917f0a2e 100644 --- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py @@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INNER JOIN room_memberships AS r USING (event_id) WHERE type = 'm.room.member' AND state_key LIKE ? """ - sql = database_engine.convert_param_style(sql) cur.execute(sql, ("%:" + config.server_name,)) cur.execute( diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql new file mode 100644 index 0000000000..7851a0a825 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql @@ -0,0 +1,20 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL -- JSON-encoded client-defined data +); diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql new file mode 100644 index 0000000000..4ed981dbf8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres index 97c1e6a0c5..c31f9af82a 100644 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', ( CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; +-- If the server has never backfilled a room then doing `-MIN(...)` will give +-- a negative result, hence why we do `GREATEST(...)` SELECT setval('events_backfill_stream_seq', ( - SELECT COALESCE(-MIN(stream_ordering), 1) FROM events + SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events )); diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql new file mode 100644 index 0000000000..985fd949a2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres new file mode 100644 index 0000000000..841186b826 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A unique and immutable mapping between instance name and an integer ID. This +-- lets us refer to instances via a small ID in e.g. stream tokens, without +-- having to encode the full name. +CREATE TABLE IF NOT EXISTS instance_map ( + instance_id SERIAL PRIMARY KEY, + instance_name TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name); diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f01cf2fd02..e34fce6281 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 5c6168e301..3c1e33819b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: """Get the room_version of a given room @@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" def __init__(self, database: DatabasePool, db_conn, hs): - super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9c1bf3c289..bc8e78e1f1 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -62,7 +62,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(StatsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() @@ -211,6 +211,7 @@ class StatsStore(StateDeltasStore): * topic * avatar * canonical_alias + * guest_access A is_federatable key can also be included with a boolean value. @@ -235,6 +236,7 @@ class StatsStore(StateDeltasStore): "topic", "avatar", "canonical_alias", + "guest_access", ): field = fields.get(col, sentinel) if field is not sentinel and (not isinstance(field, str) or "\0" in field): diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 2e95518752..e3b9ff5ca6 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -35,7 +35,6 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. """ - import abc import logging from collections import namedtuple @@ -54,7 +53,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import Collection, RoomStreamToken +from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken +from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -209,6 +210,55 @@ def _make_generic_sql_bound( ) +def _filter_results( + lower_token: Optional[RoomStreamToken], + upper_token: Optional[RoomStreamToken], + instance_name: str, + topological_ordering: int, + stream_ordering: int, +) -> bool: + """Returns True if the event persisted by the given instance at the given + topological/stream_ordering falls between the two tokens (taking a None + token to mean unbounded). + + Used to filter results from fetching events in the DB against the given + tokens. This is necessary to handle the case where the tokens include + position maps, which we handle by fetching more than necessary from the DB + and then filtering (rather than attempting to construct a complicated SQL + query). + """ + + event_historical_tuple = ( + topological_ordering, + stream_ordering, + ) + + if lower_token: + if lower_token.topological is not None: + # If these are historical tokens we compare the `(topological, stream)` + # tuples. + if event_historical_tuple <= lower_token.as_historical_tuple(): + return False + + else: + # If these are live tokens we compare the stream ordering against the + # writers stream position. + if stream_ordering <= lower_token.get_stream_pos_for_instance( + instance_name + ): + return False + + if upper_token: + if upper_token.topological is not None: + if upper_token.as_historical_tuple() < event_historical_tuple: + return False + else: + if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering: + return False + + return True + + def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create @@ -259,16 +309,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: return " AND ".join(clauses), args -class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_room_max_stream_ordering` and `get_room_min_stream_ordering` which can be called in the initializer. """ - __metaclass__ = abc.ABCMeta - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - super(StreamWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() self._send_federation = hs.should_send_federation() @@ -307,6 +355,33 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() + def get_room_max_token(self) -> RoomStreamToken: + """Get a `RoomStreamToken` that marks the current maximum persisted + position of the events stream. Useful to get a token that represents + "now". + + The token returned is a "live" token that may have an instance_map + component. + """ + + min_pos = self._stream_id_gen.get_current_token() + + positions = {} + if isinstance(self._stream_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._stream_id_gen.get_positions().items() + if p > min_pos + } + + return RoomStreamToken(None, min_pos, positions) + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], @@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if from_key == to_key: return [], from_key - from_id = from_key.stream - to_id = to_key.stream - - has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) + has_changed = self._events_stream_cache.has_entity_changed( + room_id, from_key.stream + ) if not has_changed: return [], from_key def f(txn): - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT event_id, instance_name, topological_ordering, stream_ordering + FROM events + WHERE + room_id = ? + AND not outlier + AND stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering %s LIMIT ? + """ % ( + order, + ) + txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ][:limit] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=False) if order.lower() == "desc": ret.reverse() @@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT m.event_id, instance_name, topological_ordering, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? + ORDER BY e.stream_ordering ASC + """ + txn.execute(sql, (user_id, min_from_id, max_to_id,)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ] return rows @@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int - ) -> Tuple[int, int, str]: + ) -> Optional[Tuple[int, int, str]]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return "t%d-%d" % (topo, token) - async def get_stream_id_for_event(self, event_id: str) -> int: - """The stream ID for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream ID. - """ - return await self.db_pool.runInteraction( - "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, - ) - def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none=False, ) -> int: @@ -613,26 +705,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): allow_none=allow_none, ) - async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken: - """The stream token for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream token. + async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: + """Get the persisted position for an event """ - stream_id = await self.get_stream_id_for_event(event_id) - return RoomStreamToken(None, stream_id) + row = await self.db_pool.simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "instance_name"), + desc="get_position_for_event", + ) + + return PersistedEventPosition( + row["instance_name"] or "master", row["stream_ordering"] + ) - async def get_topological_token_for_event(self, event_id: str) -> str: + async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event Args: event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A "t%d-%d" topological token. + A `RoomStreamToken` topological token. """ row = await self.db_pool.simple_select_one( table="events", @@ -640,25 +734,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + return RoomStreamToken(row["topological_ordering"], row["stream_ordering"]) - async def get_max_topological_token(self, room_id: str, stream_key: int) -> int: - """Get the max topological token in a room before the given stream + async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: + """Gets the topological token in a room after or at the given stream ordering. Args: room_id stream_key - - Returns: - The maximum topological token. """ sql = ( - "SELECT coalesce(max(topological_ordering), 0) FROM events" - " WHERE room_id = ? AND stream_ordering < ?" + "SELECT coalesce(MIN(topological_ordering), 0) FROM events" + " WHERE room_id = ? AND stream_ordering >= ?" ) row = await self.db_pool.execute( - "get_max_topological_token", None, sql, room_id, stream_key + "get_current_topological_token", None, sql, room_id, stream_key ) return row[0][0] if row else 0 @@ -692,8 +783,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): else: topo = None internal = event.internal_metadata - internal.before = str(RoomStreamToken(topo, stream - 1)) - internal.after = str(RoomStreamToken(topo, stream)) + internal.before = RoomStreamToken(topo, stream - 1) + internal.after = RoomStreamToken(topo, stream) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( @@ -980,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): else: order = "ASC" + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + if from_token.topological is not None: + from_bound = ( + from_token.as_historical_tuple() + ) # type: Tuple[Optional[int], int] + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound = None # type: Optional[Tuple[Optional[int], int]] + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -994,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): bounds += " AND " + filter_clause args.extend(filter_args) - args.append(int(limit)) + # We fetch more events as we'll filter the result set + args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" @@ -1016,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): select_keywords += "DISTINCT" sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering + %(select_keywords)s + event_id, instance_name, + topological_ordering, stream_ordering FROM events %(join_clause)s WHERE outlier = ? AND room_id = ? AND %(bounds)s @@ -1031,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, args) - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + # Filter the result set. + rows = [ + _EventDictReturn(event_id, topological_ordering, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + lower_token=to_token if direction == "b" else from_token, + upper_token=from_token if direction == "b" else to_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, + ) + ][:limit] if rows: topo = rows[-1].topological_ordering @@ -1096,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) + @cached() + async def get_id_for_instance(self, instance_name: str) -> int: + """Get a unique, immutable ID that corresponds to the given Synapse worker instance. + """ + + def _get_id_for_instance_txn(txn): + instance_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + allow_none=True, + ) + if instance_id is not None: + return instance_id + + # If we don't have an entry upsert one. + # + # We could do this before the first check, and rely on the cache for + # efficiency, but each UPSERT causes the next ID to increment which + # can quickly bloat the size of the generated IDs for new instances. + self.db_pool.simple_upsert_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + values={}, + ) + + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + ) + + return await self.db_pool.runInteraction( + "get_id_for_instance", _get_id_for_instance_txn + ) + + @cached() + async def get_name_from_instance_id(self, instance_id: int) -> str: + """Get the instance name from an ID previously returned by + `get_id_for_instance`. + """ + + return await self.db_pool.simple_select_one_onecol( + table="instance_map", + keyvalues={"instance_id": instance_id}, + retcol="instance_name", + desc="get_name_from_instance_id", + ) + class StreamStore(StreamWorkerStore): def get_room_max_stream_ordering(self) -> int: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 96ffe26cc9..9f120d3cb6 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 091367006e..7d46090267 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -43,14 +43,32 @@ _UpdateTransactionRow = namedtuple( SENTINEL = object() -class TransactionStore(SQLBaseStore): +class TransactionWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + + @wrap_as_background_process("cleanup_transactions") + async def _cleanup_transactions(self) -> None: + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + await self.db_pool.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) + + +class TransactionStore(TransactionWorkerStore): """A collection of queries for handling PDUs. """ def __init__(self, database: DatabasePool, db_conn, hs): - super(TransactionStore, self).__init__(database, db_conn, hs) - - self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + super().__init__(database, db_conn, hs) self._destination_retry_cache = ExpiringCache( cache_name="get_destination_retry_timings", @@ -218,6 +236,7 @@ class TransactionStore(SQLBaseStore): retry_interval = EXCLUDED.retry_interval WHERE EXCLUDED.retry_interval = 0 + OR destinations.retry_interval IS NULL OR destinations.retry_interval < EXCLUDED.retry_interval """ @@ -249,7 +268,11 @@ class TransactionStore(SQLBaseStore): "retry_interval": retry_interval, }, ) - elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: + elif ( + retry_interval == 0 + or prev_row["retry_interval"] is None + or prev_row["retry_interval"] < retry_interval + ): self.db_pool.simple_update_one_txn( txn, "destinations", @@ -261,22 +284,6 @@ class TransactionStore(SQLBaseStore): }, ) - def _start_cleanup_transactions(self): - return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions - ) - - async def _cleanup_transactions(self) -> None: - now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 - - def _cleanup_transactions_txn(txn): - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - - await self.db_pool.runInteraction( - "_cleanup_transactions", _cleanup_transactions_txn - ) - async def store_destination_rooms_entries( self, destinations: Iterable[str], room_id: str, stream_ordering: int, ) -> None: @@ -397,7 +404,7 @@ class TransactionStore(SQLBaseStore): @staticmethod def _get_catch_up_room_event_ids_txn( - txn, destination: str, last_successful_stream_ordering: int, + txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int, ) -> List[str]: q = """ SELECT event_id FROM destination_rooms @@ -412,3 +419,60 @@ class TransactionStore(SQLBaseStore): ) event_ids = [row[0] for row in txn] return event_ids + + async def get_catch_up_outstanding_destinations( + self, after_destination: Optional[str] + ) -> List[str]: + """ + Gets at most 25 destinations which have outstanding PDUs to be caught up, + and are not being backed off from + Args: + after_destination: + If provided, all destinations must be lexicographically greater + than this one. + + Returns: + list of up to 25 destinations with outstanding catch-up. + These are the lexicographically first destinations which are + lexicographically greater than after_destination (if provided). + """ + time = self.hs.get_clock().time_msec() + + return await self.db_pool.runInteraction( + "get_catch_up_outstanding_destinations", + self._get_catch_up_outstanding_destinations_txn, + time, + after_destination, + ) + + @staticmethod + def _get_catch_up_outstanding_destinations_txn( + txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str] + ) -> List[str]: + q = """ + SELECT destination FROM destinations + WHERE destination IN ( + SELECT destination FROM destination_rooms + WHERE destination_rooms.stream_ordering > + destinations.last_successful_stream_ordering + ) + AND destination > ? + AND ( + retry_last_ts IS NULL OR + retry_last_ts + retry_interval < ? + ) + ORDER BY destination + LIMIT 25 + """ + txn.execute( + q, + ( + # everything is lexicographically greater than "" so this gives + # us the first batch of up to 25. + after_destination or "", + now_time_ms, + ), + ) + + destinations = [row[0] for row in txn] + return destinations diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 3b9211a6d2..79b7ece330 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore): ) return [(row["user_agent"], row["ip"]) for row in rows] - -class UIAuthStore(UIAuthWorkerStore): async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore): iterable=session_ids, keyvalues={}, ) + + +class UIAuthStore(UIAuthWorkerStore): + pass diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f2f9a5799a..5a390ff2f6 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 2f7c95fc74..f9575b1f1f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore): return # They are there, delete them. - self.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "erased_users", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 139085b672..acb24e33af 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" def __init__(self, database: DatabasePool, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e924f1ca3b..0e31cc811a 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_seq_gen = build_sequence_generator( self.database_engine, get_max_state_group_txn, "state_group_id_seq" ) + self._state_group_seq_gen.check_consistency( + db_conn, table="state_groups", id_column="id" + ) @cached(max_entries=10000, iterable=True) async def get_state_group_delta(self, state_group): @@ -205,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 908cbc79e3..d6d632dc10 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -97,3 +97,20 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): """Gets a string giving the server version. For example: '3.22.0' """ ... + + @abc.abstractmethod + def in_transaction(self, conn: Connection) -> bool: + """Whether the connection is currently in a transaction. + """ + ... + + @abc.abstractmethod + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + """Attempt to set the connections autocommit mode. + + When True queries are run outside of transactions. + + Note: This has no effect on SQLite3, so callers still need to + commit/rollback the connections. + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index ff39281f85..7719ac32f7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -15,7 +15,8 @@ import logging -from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.types import Connection logger = logging.getLogger(__name__) @@ -119,6 +120,7 @@ class PostgresEngine(BaseDatabaseEngine): cursor.execute("SET synchronous_commit TO OFF") cursor.close() + db_conn.commit() @property def can_native_upsert(self): @@ -171,3 +173,9 @@ class PostgresEngine(BaseDatabaseEngine): return "%i.%i" % (numver / 10000, numver % 10000) else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) + + def in_transaction(self, conn: Connection) -> bool: + return conn.status != self.module.extensions.STATUS_READY # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + return conn.set_session(autocommit=autocommit) # type: ignore diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 8a0f8c89d1..5db0f0b520 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -17,6 +17,7 @@ import threading import typing from synapse.storage.engines import BaseDatabaseEngine +from synapse.storage.types import Connection if typing.TYPE_CHECKING: import sqlite3 # noqa: F401 @@ -86,6 +87,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): db_conn.create_function("rank", 1, _rank) db_conn.execute("PRAGMA foreign_keys = ON;") + db_conn.commit() def is_deadlock(self, error): return False @@ -105,6 +107,14 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): """ return "%i.%i.%i" % self.module.sqlite_version_info + def in_transaction(self, conn: Connection) -> bool: + return conn.in_transaction # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + # Twisted doesn't let us set attributes on the connections, so we can't + # set the connection to autocommit mode. + pass + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index d89f6ed128..4d2d88d1f0 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import Collection, StateMap +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -190,15 +190,16 @@ class EventsPersistenceStorage: self.persist_events_store = stores.persist_events self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() async def persist_events( self, - events_and_contexts: List[Tuple[EventBase, EventContext]], + events_and_contexts: Iterable[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> int: + ) -> RoomStreamToken: """ Write events to the database Args: @@ -228,11 +229,11 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return self.main_store.get_current_events_token() + return self.main_store.get_room_max_token() async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[int, int]: + ) -> Tuple[PersistedEventPosition, RoomStreamToken]: """ Returns: The stream ordering of `event`, and the stream ordering of the @@ -246,8 +247,12 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) - max_persisted_id = self.main_store.get_current_events_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) + event_stream_id = event.internal_metadata.stream_ordering + # stream ordering should have been assigned by now + assert event_stream_id + + pos = PersistedEventPosition(self._instance_name, event_stream_id) + return pos, self.main_store.get_room_max_token() def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 77de025069..9e3dfe4805 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import imp import logging import os @@ -24,9 +23,10 @@ from typing import Optional, TextIO import attr from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine -from synapse.storage.types import Connection, Cursor +from synapse.storage.types import Cursor from synapse.types import Collection logger = logging.getLogger(__name__) @@ -64,7 +64,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = ( def prepare_database( - db_conn: Connection, + db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], databases: Collection[str] = ["main", "state"], @@ -86,7 +86,7 @@ def prepare_database( """ try: - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="prepare_database") # sqlite does not automatically start transactions for DDL / SELECT statements, # so we start one before running anything. This ensures that any upgrades @@ -255,9 +255,7 @@ def _setup_new_database(cur, database_engine, databases): executescript(cur, entry.absolute_path) cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (max_current_ver, False), ) @@ -483,17 +481,13 @@ def _upgrade_existing_database( # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)" - ), + "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)", (v, relative_path), ) cur.execute("DELETE FROM schema_version") cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (v, True), ) @@ -529,10 +523,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) schemas to be applied """ cur.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_module_schemas WHERE module_name = ?" - ), - (modname,), + "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), ) applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: @@ -550,9 +541,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)" - ), + "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)", (modname, name), ) @@ -624,9 +613,7 @@ def _get_or_create_schema_state(txn, database_engine): if current_version: txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), + "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), ) applied_deltas = [d for d, in txn] diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8c4a83a840..f152f63321 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -25,7 +25,7 @@ RoomsForUser = namedtuple( ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") + "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 8f68d968f0..08a69f2f96 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,7 +20,7 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -349,7 +349,7 @@ class StateGroupStorage: async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: @@ -532,7 +532,7 @@ class StateGroupStorage: def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Awaitable[Dict[int, StateMap[str]]]: + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 2d2b560e74..970bb1b9da 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -61,3 +61,9 @@ class Connection(Protocol): def rollback(self, *args, **kwargs) -> None: ... + + def __enter__(self) -> "Connection": + ... + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1de2b91587..d7e40aaa8b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -12,17 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.types import Cursor from synapse.storage.util.sequence import PostgresSequenceGenerator logger = logging.getLogger(__name__) @@ -53,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1): """ # debug logging for https://github.com/matrix-org/synapse/issues/7968 logger.info("initialising stream generator for %s(%s)", table, column) - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: @@ -86,7 +88,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +103,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +115,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +123,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +142,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +151,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -184,12 +186,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +204,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,9 +226,7 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). @@ -251,30 +259,98 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # We check that the table and sequence haven't diverged. + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) + + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } + ): + cur = db_conn.cursor(txn_name="_load_current_ids") + + # Load the current positions of all writers for the stream. + if self._writers: + # We delete any stale entries in the positions table. This is + # important if we add back a writer after a long time; we want to + # consider that a "new" writer, rather than using the old stale + # entry here. + sql = """ + DELETE FROM stream_positions + WHERE + stream_name = ? + AND instance_name != ALL(?) + """ + cur.execute(sql, (self._stream_name, self._writers)) + + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? + """ + cur.execute(sql, (self._stream_name,)) + + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } - cur = db_conn.cursor() - cur.execute(sql) + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + # We add a GREATEST here to ensure that the result is always + # positive. (This can be a problem for e.g. backfill streams where + # the server has never backfilled). + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + cur.execute(sql, (min_stream_id * self._return_factor,)) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + self._persisted_upto_position = min_stream_id - cur.close() + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) - return current_positions + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -282,59 +358,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. - with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - self._unfinished_ids.add(next_id) - - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) + return _MultiWriterCtxManager(self) - return manager() - - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -352,6 +392,21 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): @@ -363,7 +418,7 @@ class MultiWriterIdGenerator: self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) - new_cur = None + new_cur = None # type: Optional[int] if self._unfinished_ids: # If there are unfinished IDs then the new position will be the @@ -408,11 +463,22 @@ class MultiWriterIdGenerator: """Returns the position of the given writer. """ + # If we don't have an entry for the given instance name, we assume it's a + # new writer. + # + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get( + instance_name, self._persisted_upto_position + ) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. + + Note that this won't necessarily include all configured writers if some + writers haven't written anything yet. """ with self._lock: @@ -482,3 +548,104 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + def _update_stream_positions_table_txn(self, txn: Cursor): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) + ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + # It's safe to run this in autocommit mode as fetching values from a + # sequence ignores transaction semantics anyway. + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + db_autocommit=True, + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self.id_gen._lock: + assert max(self.id_gen._current_positions.values(), default=0) < min( + self.stream_ids + ) + + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + # + # We do this in autocommit mode as a) the upsert works correctly outside + # transactions and b) reduces the amount of time the rows are locked + # for. If we don't do this then we'll often hit serialization errors due + # to the fact we default to REPEATABLE READ isolation levels. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + db_autocommit=True, + ) + + return False diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index ffc1894748..ff2d038ad2 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -13,11 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import threading from typing import Callable, List, Optional -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.engines import ( + BaseDatabaseEngine, + IncorrectDatabaseSetup, + PostgresEngine, +) +from synapse.storage.types import Connection, Cursor + +logger = logging.getLogger(__name__) + + +_INCONSISTENT_SEQUENCE_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated +table '%(table)s'. This can happen if Synapse has been downgraded and +then upgraded again, or due to a bad migration. + +To fix this error, shut down Synapse (including any and all workers) +and run the following SQL: + + SELECT setval('%(seq)s', ( + %(max_id_sql)s + )); + +See docs/postgres.md for more information. +""" class SequenceGenerator(metaclass=abc.ABCMeta): @@ -28,6 +52,23 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """Gets the next ID in the sequence""" ... + @abc.abstractmethod + def check_consistency( + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, + ): + """Should be called during start up to test that the current value of + the sequence is greater than or equal to the maximum ID in the table. + + This is to handle various cases where the sequence value can get out + of sync with the table, e.g. if Synapse gets rolled back to a previous + version and the rolled forwards again. + """ + ... + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -45,6 +86,54 @@ class PostgresSequenceGenerator(SequenceGenerator): ) return [i for (i,) in txn] + def check_consistency( + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, + ): + txn = db_conn.cursor(txn_name="sequence.check_consistency") + + # First we get the current max ID from the table. + table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { + "id": id_column, + "table": table, + "agg": "MAX" if positive else "-MIN", + } + + txn.execute(table_sql) + row = txn.fetchone() + if not row: + # Table is empty, so nothing to do. + txn.close() + return + + # Now we fetch the current value from the sequence and compare with the + # above. + max_stream_id = row[0] + txn.execute( + "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} + ) + last_value, is_called = txn.fetchone() + txn.close() + + # If `is_called` is False then `last_value` is actually the value that + # will be generated next, so we decrement to get the true "last value". + if not is_called: + last_value -= 1 + + if max_stream_id > last_value: + logger.warning( + "Postgres sequence %s is behind table %s: %d < %d", + last_value, + max_stream_id, + ) + raise IncorrectDatabaseSetup( + _INCONSISTENT_SEQUENCE_ERROR + % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -81,6 +170,12 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + # There is nothing to do for in memory sequences + pass + def build_sequence_generator( database_engine: BaseDatabaseEngine, |