Diffstat (limited to 'synapse/storage/databases/main')
52 files changed, 2717 insertions, 1292 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0cb12f4c61..43660ec4fb 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar import logging -import time from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState @@ -148,7 +146,6 @@ class DataStore( db_conn, "e2e_cross_signing_keys", "stream_id" ) - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") @@ -268,9 +265,6 @@ class DataStore( self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -289,7 +283,6 @@ class DataStore( " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" ) - sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) @@ -301,192 +294,6 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - async def count_daily_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return await self.db_pool.runInteraction( - "count_daily_users", self._count_users, yesterday - ) - - async def count_monthly_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return await self.db_pool.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() - return count - - async def count_r30_users(self) -> Dict[str, int]: - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns: - A mapping of counts globally as well as broken out by platform. 
- """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = txn.fetchone() - results["all"] = count - - return results - - return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - async def generate_user_daily_visits(self) -> None: - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. 
- if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - await self.db_pool.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index ef81d73573..49ee23470d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,6 +18,7 @@ import abc import logging from typing import Dict, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator @@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext ) -> bool: ignored_account_data = await self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", + AccountDataTypes.IGNORED_USER_LIST, ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False - return ignored_user_id in ignored_account_data.get("ignored_users", {}) + try: + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + except TypeError: + # The type of the ignored_users field is invalid. + return False class AccountDataStore(AccountDataWorkerStore): diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 85f6b1e3fd..e550cbc866 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -15,18 +15,31 @@ # limitations under the License. import logging import re +from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple -from synapse.appservice import AppServiceTransaction +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + AppServiceTransaction, +) from synapse.config.appservice import load_appservices +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.types import Connection +from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) -def _make_exclusive_regex(services_cache): +def _make_exclusive_regex( + services_cache: List[ApplicationService], +) -> Optional[Pattern]: # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. 
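The regex construction that follows can be pictured in isolation. A minimal sketch of the same idea, with made-up example patterns; returning None for the empty case matters because joining zero patterns would yield a regex that matches every user:

```python
import re
from typing import List, Optional, Pattern

def make_exclusive_regex(regexes: List[str]) -> Optional[Pattern]:
    # OR-join every exclusive-user regex so one match() call answers
    # "is this user exclusively claimed by an application service?".
    if not regexes:
        # "" would match everything, so signal "no exclusive users" instead.
        return None
    return re.compile("|".join("(" + r + ")" for r in regexes))

# Usage with illustrative namespaces:
pattern = make_exclusive_regex([r"@irc_.*:example\.com", r"@slack_.*:example\.com"])
assert pattern is not None and pattern.match("@irc_alice:example.com")
```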
exclusive_user_regexes = [ @@ -36,17 +49,19 @@ def _make_exclusive_regex(services_cache): ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) + exclusive_user_pattern = re.compile( + exclusive_user_regex + ) # type: Optional[Pattern] else: # We handle this case specially otherwise the constructed regex # will always match - exclusive_user_regex = None + exclusive_user_pattern = None - return exclusive_user_regex + return exclusive_user_pattern class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -57,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): def get_app_services(self): return self.services_cache - def get_if_app_services_interested_in_user(self, user_id): + def get_if_app_services_interested_in_user(self, user_id: str) -> bool: """Check if the user is one associated with an app service (exclusively) """ if self.exclusive_user_regex: @@ -65,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): else: return False - def get_app_service_by_user_id(self, user_id): + def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]: """Retrieve an application service from their user ID. All application services have associated with them a particular user ID. @@ -74,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore): a user ID to an application service. Args: - user_id(str): The user ID to see if it is an application service. + user_id: The user ID to see if it is an application service. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.sender == user_id: return service return None - def get_app_service_by_token(self, token): + def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]: """Get the application service with the given appservice token. Args: - token (str): The application service token. + token: The application service token. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.token == token: return service return None - def get_app_service_by_id(self, as_id): + def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: """Get the application service with the given appservice ID. Args: - as_id (str): The application service ID. + as_id: The application service ID. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.id == as_id: @@ -121,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - async def get_appservices_by_state(self, state): + async def get_appservices_by_state( + self, state: ApplicationServiceState + ) -> List[ApplicationService]: """Get a list of application services based on their state. Args: - state(ApplicationServiceState): The state to filter on. + state: The state to filter on. Returns: A list of ApplicationServices, which may be empty. 
""" @@ -142,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - async def get_appservice_state(self, service): + async def get_appservice_state( + self, service: ApplicationService + ) -> Optional[ApplicationServiceState]: """Get the application service state. Args: - service(ApplicationService): The service whose state to set. + service: The service whose state to set. Returns: - An ApplicationServiceState. + An ApplicationServiceState or none. """ result = await self.db_pool.simple_select_one( "application_services_state", @@ -161,26 +180,36 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - async def set_appservice_state(self, service, state) -> None: + async def set_appservice_state( + self, service: ApplicationService, state: ApplicationServiceState + ) -> None: """Set the application service state. Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. + service: The service whose state to set. + state: The connectivity state to apply. """ await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) - async def create_appservice_txn(self, service, events): + async def create_appservice_txn( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict], + ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service - with the given list of events. + with the given list of events. Ephemeral events are NOT persisted to the + database and are not resent if a transaction is retried. Args: - service(ApplicationService): The service who the transaction is for. - events(list<Event>): A list of events to put in the transaction. + service: The service who the transaction is for. + events: A list of persistent events to put in the transaction. + ephemeral: A list of ephemeral events to put in the transaction. + Returns: - AppServiceTransaction: A new transaction. + A new transaction. """ def _create_appservice_txn(txn): @@ -207,19 +236,22 @@ class ApplicationServiceTransactionWorkerStore( "VALUES(?,?,?)", (service.id, new_txn_id, event_ids), ) - return AppServiceTransaction(service=service, id=new_txn_id, events=events) + return AppServiceTransaction( + service=service, id=new_txn_id, events=events, ephemeral=ephemeral + ) return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - async def complete_appservice_txn(self, txn_id, service) -> None: + async def complete_appservice_txn( + self, txn_id: int, service: ApplicationService + ) -> None: """Completes an application service transaction. Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. + txn_id: The transaction ID being completed. + service: The application service which was sent this transaction. """ txn_id = int(txn_id) @@ -259,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - async def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. + async def get_oldest_unsent_txn( + self, service: ApplicationService + ) -> Optional[AppServiceTransaction]: + """Get the oldest transaction which has not been sent for this service. Args: - service(ApplicationService): The app service to get the oldest txn. 
+ service: The app service to get the oldest txn. Returns: An AppServiceTransaction or None. """ @@ -296,9 +329,11 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + return AppServiceTransaction( + service=service, id=entry["txn_id"], events=events, ephemeral=[] + ) - def _get_last_txn(self, txn, service_id): + def _get_last_txn(self, txn, service_id: Optional[str]) -> int: txn.execute( "SELECT last_txn FROM application_services_state WHERE as_id=?", (service_id,), @@ -309,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - async def set_appservice_last_pos(self, pos) -> None: + async def set_appservice_last_pos(self, pos: int) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) @@ -319,8 +354,10 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - async def get_new_events_for_appservice(self, current_id, limit): - """Get all new evnets""" + async def get_new_events_for_appservice( + self, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: + """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): sql = ( @@ -351,6 +388,54 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, events + async def get_type_stream_id_for_appservice( + self, service: ApplicationService, type: str + ) -> int: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + + def get_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + # We do NOT want to escape `stream_id_type`. + "SELECT %s FROM application_services_state WHERE as_id=?" + % stream_id_type, + (service.id,), + ) + last_stream_id = txn.fetchone() + if last_stream_id is None or last_stream_id[0] is None: # no row exists + return 0 + else: + return int(last_stream_id[0]) + + return await self.db_pool.runInteraction( + "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn + ) + + async def set_type_stream_id_for_appservice( + self, service: ApplicationService, type: str, pos: Optional[int] + ) -> None: + if type not in ("read_receipt", "presence"): + raise ValueError( + "Expected type to be a valid application stream id type, got %s" + % (type,) + ) + + def set_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + "UPDATE application_services_state SET %s = ? WHERE as_id=?" 
+ % stream_id_type, + (pos, service.id), + ) + + await self.db_pool.runInteraction( + "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + ) + class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): # This is currently empty due to there not being any AS storage functions diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index f211ddbaf8..3e26d5ba87 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -17,12 +17,12 @@ import logging from typing import TYPE_CHECKING from synapse.events.utils import prune_event_dict -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.databases.main.events import encode_json from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) - def _censor_redactions(): - return run_as_background_process( - "_censor_redactions", self._censor_redactions - ) - - if self.hs.config.redaction_retention_period is not None: - hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + if ( + hs.config.run_background_tasks + and self.hs.config.redaction_retention_period is not None + ): + hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) + @wrap_as_background_process("_censor_redactions") async def _censor_redactions(self): """Censors all redactions older than the configured period that haven't been censored yet. @@ -105,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json( + pruned_json = json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -171,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. 
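The change that follows swaps the module-local encode_json for the shared json_encoder when storing the pruned event. A toy version of the prune-then-encode step; the key allow-list here is abbreviated and purely illustrative (the real rules live in synapse.events.utils.prune_event_dict):

```python
import json

# Abbreviated allow-list of top-level keys that survive redaction; the real
# list, and the per-event-type content rules, are defined by prune_event_dict.
ALLOWED_KEYS = {
    "event_id", "type", "room_id", "sender", "origin_server_ts",
    "state_key", "hashes", "signatures", "depth",
    "prev_events", "auth_events",
}

def prune_and_encode(event_dict: dict) -> str:
    # Drop everything a redaction removes, then serialise compactly for the
    # event_json table (mirroring json_encoder's compact separators).
    pruned = {k: v for k, v in event_dict.items() if k in ALLOWED_KEYS}
    pruned["content"] = {}  # content survives only for a few event types
    return json.dumps(pruned, separators=(",", ":"))
```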
- pruned_json = encode_json( + pruned_json = json_encoder.encode( prune_event_dict(event.room_version, event.get_dict()) ) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 239c7a949c..339bd691a4 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -351,16 +351,70 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return updated -class ClientIpStore(ClientIpBackgroundUpdateStore): +class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + if hs.config.run_background_tasks and self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.db_pool.updates.has_completed_background_update( + "devices_last_seen" + ): + # Only start pruning if we have finished populating the devices + # last seen info. + return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? 
+ ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ - self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn ) - super().__init__(database, db_conn, hs) - self.user_ips_max_age = hs.config.user_ips_max_age +class ClientIpStore(ClientIpWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 + ) + + super().__init__(database, db_conn, hs) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update = {} @@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): @@ -391,7 +442,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self._batch_row_update[key] = (user_agent, device_id, now) @@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } for (access_token, ip), (user_agent, last_seen) in results.items() ] - - @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ - - if self.user_ips_max_age is None: - # Nothing to do - return - - if not await self.db_pool.updates.has_completed_background_update( - "devices_last_seen" - ): - # Only start pruning if we have finished populating the devices - # last seen info. - return - - # We do a slightly funky SQL delete to ensure we don't try and delete - # too much at once (as the table may be very large from before we - # started pruning). - # - # This works by finding the max last_seen that is less than the given - # time, but has no more than N rows before it, deleting all rows with - # a lesser last_seen time. (We COALESCE so that the sub-SELECT always - # returns exactly one row). - sql = """ - DELETE FROM user_ips - WHERE last_seen <= ( - SELECT COALESCE(MAX(last_seen), -1) - FROM ( - SELECT last_seen FROM user_ips - WHERE last_seen <= ? - ORDER BY last_seen ASC - LIMIT 5000 - ) AS u - ) - """ - - timestamp = self.clock.time_msec() - self.user_ips_max_age - - def _prune_old_user_ips_txn(txn): - txn.execute(sql, (timestamp,)) - - await self.db_pool.runInteraction( - "_prune_old_user_ips", _prune_old_user_ips_txn - ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fdf394c612..dfb4f87b8f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
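The batched-delete trick used by _prune_old_user_ips above (delete at most 5000 stale rows per run, so the first prune of a huge table cannot stall the database) works standalone. A minimal sketch against SQLite, assuming a user_ips table with a last_seen column:

```python
import sqlite3

BATCH = 5000

def prune_once(conn: sqlite3.Connection, cutoff_ms: int) -> int:
    # Find the largest last_seen within the next BATCH stale rows, then
    # delete everything at or below it; calling this repeatedly drains the
    # backlog in bounded chunks.
    sql = """
        DELETE FROM user_ips
        WHERE last_seen <= (
            SELECT COALESCE(MAX(last_seen), -1)
            FROM (
                SELECT last_seen FROM user_ips
                WHERE last_seen <= ?
                ORDER BY last_seen ASC
                LIMIT ?
            ) AS u
        )
    """
    cur = conn.execute(sql, (cutoff_ms, BATCH))
    conn.commit()
    return cur.rowcount  # number of rows removed this pass
```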
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import ( trace, whitelisted_homeserver, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -33,8 +33,9 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key -from synapse.util import json_encoder -from synapse.util.caches.descriptors import Cache, cached, cachedList +from synapse.util import json_decoder, json_encoder +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -48,6 +49,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + self._clock.looping_call( + self._prune_old_outbound_device_pokes, 60 * 60 * 1000 + ) + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -698,6 +707,172 @@ class DeviceWorkerStore(SQLBaseStore): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + # FIXME: make sure device ID still exists in devices table + row = await self.db_pool.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return ( + (row["device_id"], json_decoder.decode(row["device_data"])) if row else None + ) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + self.db_pool.simple_upsert_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + values={"device_id": device_id, "device_data": device_data}, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: JsonDict + ) -> Optional[str]: + """Store a dehydrated device for a user. + + Args: + user_id: the user that we are storing the device for + device_id: the ID of the dehydrated device + device_data: the dehydrated device information + Returns: + device id of the user's previous dehydrated device, if any + """ + return await self.db_pool.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, + device_id, + json_encoder.encode(device_data), + ) + + async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: + """Remove a dehydrated device. 
+ + Args: + user_id: the user that the dehydrated device belongs to + device_id: the ID of the dehydrated device + """ + count = await self.db_pool.simple_delete( + "dehydrated_devices", + {"user_id": user_id, "device_id": device_id}, + desc="remove_dehydrated_device", + ) + return count >= 1 + + @wrap_as_background_process("prune_old_outbound_device_pokes") + async def _prune_old_outbound_device_pokes( + self, prune_age: int = 24 * 60 * 60 * 1000 + ) -> None: + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. + """ + yesterday = self._clock.time_msec() - prune_age + + def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. + select_sql = """ + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 + GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 + """ + + txn.execute(select_sql, (yesterday,)) + rows = txn.fetchall() + + if not rows: + return + + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) + ) + """ + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount + + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? 
+ """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + + logger.info("Pruned %d device list outbound pokes", count) + + await self.db_pool.runInteraction( + "_prune_old_outbound_device_pokes", _prune_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -830,14 +1005,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = Cache( - name="device_id_exists", keylen=2, max_entries=10000 + self.device_id_exists_cache = LruCache( + cache_name="device_id_exists", keylen=2, max_size=10000 ) - self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: str + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -879,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - self.device_id_exists_cache.prefill(key, True) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: raise @@ -955,7 +1128,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def update_remote_device_list_cache_entry( - self, user_id: str, device_id: str, content: JsonDict, stream_id: int + self, user_id: str, device_id: str, content: JsonDict, stream_id: str ) -> None: """Updates a single device in the cache of a remote user's devicelist. @@ -983,7 +1156,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, content: JsonDict, - stream_id: int, + stream_id: str, ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( @@ -1193,95 +1366,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): for device_id in device_ids ], ) - - def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000): - """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. - - Normally, we try to send device updates as a delta since a previous known point: - this is done by setting the prev_id in the m.device_list_update EDU. However, - for that to work, we have to have a complete record of each change to - each device, which can add up to quite a lot of data. - - An alternative mechanism is that, if the remote server sees that it has missed - an entry in the stream_id sequence for a given user, it will request a full - list of that user's devices. Hence, we can reduce the amount of data we have to - store (and transmit in some future transaction), by clearing almost everything - for a given destination out of the database, and having the remote server - resync. - - All we need to do is make sure we keep at least one row for each - (user, destination) pair, to remind us to send a m.device_list_update EDU for - that user when the destination comes back. It doesn't matter which device - we keep. - """ - yesterday = self._clock.time_msec() - prune_age - - def _prune_txn(txn): - # look for (user, destination) pairs which have an update older than - # the cutoff. - # - # For each pair, we also need to know the most recent stream_id, and - # an arbitrary device_id at that stream_id. 
- select_sql = """ - SELECT - dlop1.destination, - dlop1.user_id, - MAX(dlop1.stream_id) AS stream_id, - (SELECT MIN(dlop2.device_id) AS device_id FROM - device_lists_outbound_pokes dlop2 - WHERE dlop2.destination = dlop1.destination AND - dlop2.user_id=dlop1.user_id AND - dlop2.stream_id=MAX(dlop1.stream_id) - ) - FROM device_lists_outbound_pokes dlop1 - GROUP BY destination, user_id - HAVING min(ts) < ? AND count(*) > 1 - """ - - txn.execute(select_sql, (yesterday,)) - rows = txn.fetchall() - - if not rows: - return - - logger.info( - "Pruning old outbound device list updates for %i users/destinations: %s", - len(rows), - shortstr((row[0], row[1]) for row in rows), - ) - - # we want to keep the update with the highest stream_id for each user. - # - # there might be more than one update (with different device_ids) with the - # same stream_id, so we also delete all but one rows with the max stream id. - delete_sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND ( - stream_id < ? OR - (stream_id = ? AND device_id != ?) - ) - """ - count = 0 - for (destination, user_id, stream_id, device_id) in rows: - txn.execute( - delete_sql, (destination, user_id, stream_id, stream_id, device_id) - ) - count += txn.rowcount - - # Since we've deleted unsent deltas, we need to remove the entry - # of last successful sent so that the prev_ids are correctly set. - sql = """ - DELETE FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? - """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) - - logger.info("Pruned %d device list outbound pokes", count) - - return run_as_background_process( - "prune_old_outbound_device_pokes", - self.db_pool.runInteraction, - "_prune_old_outbound_device_pokes", - _prune_txn, - ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 22e1ed15d0..4d1b92d1aa 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder @@ -33,6 +33,7 @@ from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.handlers.e2e_keys import SignatureListItem + from synapse.server import HomeServer @attr.s(slots=True) @@ -47,7 +48,20 @@ class DeviceKeyLookupResult: keys = attr.ib(type=Optional[JsonDict]) -class EndToEndKeyWorkerStore(SQLBaseStore): +class EndToEndKeyBackgroundStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "e2e_cross_signing_keys_idx", + index_name="e2e_cross_signing_keys_stream_idx", + table="e2e_cross_signing_keys", + columns=["stream_id"], + unique=True, + ) + + +class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_device_keys_for_federation_query( self, user_id: str ) -> Tuple[int, List[JsonDict]]: @@ -367,6 +381,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def set_e2e_fallback_keys( + self, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ + # fallback_keys will usually only have one item in it, so using a for + # loop (as opposed to calling simple_upsert_many_txn) won't be too bad + # FIXME: make sure that only one key per algorithm is uploaded + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + await self.db_pool.simple_upsert( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + desc="set_e2e_fallback_key", + ) + + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) + ) + + @cached(max_entries=10000) + async def get_e2e_unused_fallback_key_types( + self, user_id: str, device_id: str + ) -> List[str]: + """Returns the fallback key types that have an unused key. + + Args: + user_id: the user whose keys are being queried + device_id: the device whose keys are being queried + + Returns: + a list of key types + """ + return await self.db_pool.simple_select_onecol( + "e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id, "used": False}, + retcol="algorithm", + desc="get_e2e_unused_fallback_key_types", + ) + async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Optional[dict]: @@ -701,15 +770,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): " WHERE user_id = ? AND device_id = ? AND algorithm = ?" " LIMIT 1" ) + fallback_sql = ( + "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" 
+ " LIMIT 1" + ) result = {} delete = [] + used_fallbacks = [] for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: + otk_row = txn.fetchone() + if otk_row is not None: + key_id, key_json = otk_row device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) + else: + # no one-time key available, so see if there's a fallback + # key + txn.execute(fallback_sql, (user_id, device_id, algorithm)) + fallback_row = txn.fetchone() + if fallback_row is not None: + key_id, key_json, used = fallback_row + device_result[algorithm + ":" + key_id] = key_json + if not used: + used_fallbacks.append( + (user_id, device_id, algorithm, key_id) + ) + + # drop any one-time keys that were claimed sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -726,6 +817,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + # mark fallback keys as used + for user_id, device_id, algorithm, key_id in used_fallbacks: + self.db_pool.simple_update_txn( + txn, + "e2e_fallback_keys_json", + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + {"used": True}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) + return result return await self.db_pool.runInteraction( @@ -754,6 +862,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6d3689c09e..2e07c37340 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -19,19 +19,33 @@ from typing import Dict, Iterable, List, Set, Tuple from synapse.api.errors import StoreError from synapse.events import EventBase -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.types import Collection from synapse.util.caches.descriptors import cached +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + 
super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + # Cache of event ID to list of auth event IDs and their depths. + self._event_auth_cache = LruCache( + 500000, "_event_auth_cache", size_callback=len + ) # type: LruCache[str, List[Tuple[str, int]]] + async def get_auth_chain( self, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: @@ -76,17 +90,45 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: results = set() - base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE " + # We pull out the depth simply so that we can populate the + # `_event_auth_cache` cache. + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ front = set(event_ids) while front: new_front = set() for chunk in batch_iter(front, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", chunk - ) - txn.execute(base_sql + clause, args) - new_front.update(r[0] for r in txn) + # Pull the auth events either from the cache or DB. + to_fetch = [] # Event IDs to fetch from DB # type: List[str] + for event_id in chunk: + res = self._event_auth_cache.get(event_id) + if res is None: + to_fetch.append(event_id) + else: + new_front.update(auth_id for auth_id, depth in res) + + if to_fetch: + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", to_fetch + ) + txn.execute(base_sql + clause, args) + + # Note we need to batch up the results by event ID before + # adding to the cache. + to_cache = {} + for event_id, auth_event_id, auth_event_depth in txn: + to_cache.setdefault(event_id, []).append( + (auth_event_id, auth_event_depth) + ) + new_front.add(auth_event_id) + + for event_id, auth_events in to_cache.items(): + self._event_auth_cache.set(event_id, auth_events) new_front -= results @@ -205,14 +247,38 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas break # Fetch the auth events and their depths of the N last events we're - # currently walking + # currently walking, either from cache or DB. search, chunk = search[:-100], search[-100:] - clause, args = make_in_list_sql_clause( - txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] - ) - txn.execute(base_sql + clause, args) - for event_id, auth_event_id, auth_event_depth in txn: + found = [] # Results found # type: List[Tuple[str, str, int]] + to_fetch = [] # Event IDs to fetch from DB # type: List[str] + for _, event_id in chunk: + res = self._event_auth_cache.get(event_id) + if res is None: + to_fetch.append(event_id) + else: + found.extend((event_id, auth_id, depth) for auth_id, depth in res) + + if to_fetch: + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", to_fetch + ) + txn.execute(base_sql + clause, args) + + # We parse the results and add the to the `found` set and the + # cache (note we need to batch up the results by event ID before + # adding to the cache). 
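The cache-then-batch-fetch pattern described in the comments above, reduced to a standalone helper; the fetch_from_db callable stands in for the event_auth/events join:

```python
from typing import Callable, Dict, Iterable, List, Tuple

AuthEntry = Tuple[str, int]  # (auth_event_id, depth)

def lookup_auth_events(
    cache: Dict[str, List[AuthEntry]],
    fetch_from_db: Callable[[List[str]], Iterable[Tuple[str, str, int]]],
    event_ids: Iterable[str],
) -> Dict[str, List[AuthEntry]]:
    # Serve what we can from the cache, batch the misses into one query,
    # group the DB rows by event ID, and write the grouped lists back into
    # the cache before returning.
    results: Dict[str, List[AuthEntry]] = {}
    to_fetch: List[str] = []
    for event_id in event_ids:
        cached = cache.get(event_id)
        if cached is None:
            to_fetch.append(event_id)
        else:
            results[event_id] = cached

    if to_fetch:
        grouped: Dict[str, List[AuthEntry]] = {}
        for event_id, auth_id, depth in fetch_from_db(to_fetch):
            grouped.setdefault(event_id, []).append((auth_id, depth))
        for event_id, auth_events in grouped.items():
            cache[event_id] = auth_events
            results[event_id] = auth_events
    return results
```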
+ to_cache = {} + for event_id, auth_event_id, auth_event_depth in txn: + to_cache.setdefault(event_id, []).append( + (auth_event_id, auth_event_depth) + ) + found.append((event_id, auth_event_id, auth_event_depth)) + + for event_id, auth_events in to_cache.items(): + self._event_auth_cache.set(event_id, auth_events) + + for event_id, auth_event_id, auth_event_depth in found: event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) sets = event_to_missing_sets.get(auth_event_id) @@ -586,6 +652,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [row["event_id"] for row in rows] + @wrap_as_background_process("delete_old_forward_extrem_cache") + async def _delete_old_forward_extrem_cache(self) -> None: + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = """ + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """ + txn.execute( + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) + ) + + await self.db_pool.runInteraction( + "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, + ) + class EventFederationStore(EventFederationWorkerStore): """ Responsible for storing and serving up the various graphs associated @@ -606,34 +694,6 @@ class EventFederationStore(EventFederationWorkerStore): self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) - - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = """ - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? - """ - txn.execute( - sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) - ) - - return run_as_background_process( - "delete_old_forward_extrem_cache", - self.db_pool.runInteraction, - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn, - ) - async def clean_room_for_join(self, room_id): return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 62f1738732..2e56dfaf31 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -13,15 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import logging from typing import Dict, List, Optional, Tuple, Union import attr -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -74,19 +73,21 @@ class EventPushActionsWorkerStore(SQLBaseStore): self.stream_ordering_month_ago = None self.stream_ordering_day_ago = None - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) + cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) cur.close() self.find_stream_orderings_looping_call = self._clock.looping_call( self._find_stream_orderings_for_times, 10 * 60 * 1000 ) + self._rotate_delay = 3 self._rotate_count = 10000 + self._doing_notif_rotation = False + if hs.config.run_background_tasks: + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) @cached(num_args=3, tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( @@ -518,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): "Error removing push actions after event persistence failure" ) - def _find_stream_orderings_for_times(self): - return run_as_background_process( - "event_push_action_stream_orderings", - self.db_pool.runInteraction, + @wrap_as_background_process("event_push_action_stream_orderings") + async def _find_stream_orderings_for_times(self) -> None: + await self.db_pool.runInteraction( "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) - def _find_stream_orderings_for_times_txn(self, txn): + def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None: logger.info("Searching for stream ordering 1 month ago") self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 @@ -656,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) return result[0] if result else None - -class EventPushActionsStore(EventPushActionsWorkerStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self.db_pool.updates.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.db_pool.updates.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1", - ) - - self._doing_notif_rotation = False - self._rotate_notif_loop = self._clock.looping_call( - self._start_rotate_notifs, 30 * 60 * 1000 - ) - - async def get_push_actions_for_user( - self, user_id, before=None, limit=50, only_highlight=False - ): - def f(txn): - before_clause = "" - if before: - before_clause = "AND epa.stream_ordering < ?" 
- args = [user_id, before, limit] - else: - args = [user_id, limit] - - if only_highlight: - if len(before_clause) > 0: - before_clause += " " - before_clause += "AND epa.highlight = 1" - - # NB. This assumes event_ids are globally unique since - # it makes the query easier to index - sql = ( - "SELECT epa.event_id, epa.room_id," - " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" - " FROM event_push_actions epa, events e" - " WHERE epa.event_id = e.event_id" - " AND epa.user_id = ? %s" - " AND epa.notif = 1" - " ORDER BY epa.stream_ordering DESC" - " LIMIT ?" % (before_clause,) - ) - txn.execute(sql, args) - return self.db_pool.cursor_to_dict(txn) - - push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) - for pa in push_actions: - pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) - return push_actions - - async def get_latest_push_action_stream_ordering(self): - def f(txn): - txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") - return txn.fetchone() - - result = await self.db_pool.runInteraction( - "get_latest_push_action_stream_ordering", f - ) - return result[0] or 0 - - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? 
- """,
- (room_id, user_id, stream_ordering),
- )
-
- def _start_rotate_notifs(self):
- return run_as_background_process("rotate_notifs", self._rotate_notifs)
-
+ @wrap_as_background_process("rotate_notifs")
async def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
@@ -958,6 +836,121 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
+class EventPushActionsStore(EventPushActionsWorkerStore):
+ EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+ def __init__(self, database: DatabasePool, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ self.EPA_HIGHLIGHT_INDEX,
+ index_name="event_push_actions_u_highlight",
+ table="event_push_actions",
+ columns=["user_id", "stream_ordering"],
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1",
+ )
+
+ async def get_push_actions_for_user(
+ self, user_id, before=None, limit=50, only_highlight=False
+ ):
+ def f(txn):
+ before_clause = ""
+ if before:
+ before_clause = "AND epa.stream_ordering < ?"
+ args = [user_id, before, limit]
+ else:
+ args = [user_id, limit]
+
+ if only_highlight:
+ if len(before_clause) > 0:
+ before_clause += " "
+ before_clause += "AND epa.highlight = 1"
+
+ # NB. This assumes event_ids are globally unique since
+ # it makes the query easier to index
+ sql = (
+ "SELECT epa.event_id, epa.room_id,"
+ " epa.stream_ordering, epa.topological_ordering,"
+ " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
+ " FROM event_push_actions epa, events e"
+ " WHERE epa.event_id = e.event_id"
+ " AND epa.user_id = ? %s"
+ " AND epa.notif = 1"
+ " ORDER BY epa.stream_ordering DESC"
+ " LIMIT ?" % (before_clause,)
+ )
+ txn.execute(sql, args)
+ return self.db_pool.cursor_to_dict(txn)
+
+ push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
+ for pa in push_actions:
+ pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
+ return push_actions
+
+ async def get_latest_push_action_stream_ordering(self):
+ def f(txn):
+ txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
+ return txn.fetchone()
+
+ result = await self.db_pool.runInteraction(
+ "get_latest_push_action_stream_ordering", f
+ )
+ return result[0] or 0
+
+ def _remove_old_push_actions_before_txn(
+ self, txn, room_id, user_id, stream_ordering
+ ):
+ """
+ Purges old push actions for a user and room before a given
+ stream_ordering.
+
+ We however keep a month's worth of highlighted notifications, so that
+ users can still get a list of recent highlights.
+
+ Args:
+ txn: The transaction
+ room_id: Room ID to delete from
+ user_id: user ID to delete for
+ stream_ordering: The lowest stream ordering which will
+ not be deleted.
+ """
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id, user_id),
+ )
+
+ # We need to join on the events table to get the received_ts for
+ # event_push_actions and sqlite won't let us use a join in a delete so
+ # we can't just delete where received_ts < x. Furthermore we can
+ # only identify event_push_actions by a tuple of room_id, event_id
+ # so we can't use a subquery.
+ # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + + def _action_has_highlight(actions): for action in actions: try: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 18def01f50..90fb1a1f00 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id -from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util import json_encoder from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -52,16 +52,6 @@ event_counter = Counter( ) -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -341,6 +331,10 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # stream orderings should have been assigned by now + assert min_stream_order + assert max_stream_order + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, @@ -367,6 +361,8 @@ class PersistEventsStore: self._store_event_txn(txn, events_and_contexts=events_and_contexts) + self._persist_transaction_ids_txn(txn, events_and_contexts) + # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) @@ -411,6 +407,35 @@ class PersistEventsStore: # room_memberships, where applicable. self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + def _persist_transaction_ids_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ): + """Persist the mapping from transaction IDs to event IDs (if defined). + """ + + to_insert = [] + for event, _ in events_and_contexts: + token_id = getattr(event.internal_metadata, "token_id", None) + txn_id = getattr(event.internal_metadata, "txn_id", None) + if token_id and txn_id: + to_insert.append( + { + "event_id": event.event_id, + "room_id": event.room_id, + "user_id": event.sender, + "token_id": token_id, + "txn_id": txn_id, + "inserted_ts": self._clock.time_msec(), + } + ) + + if to_insert: + self.db_pool.simple_insert_many_txn( + txn, table="event_txn_id", values=to_insert, + ) + def _update_current_state_txn( self, txn: LoggingTransaction, @@ -432,12 +457,12 @@ class PersistEventsStore: # so that async background tasks get told what happened. 
sql = """ INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, room_id, type, state_key, null, event_id + (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, room_id, type, state_key, null, event_id FROM current_state_events WHERE room_id = ? """ - txn.execute(sql, (stream_id, room_id)) + txn.execute(sql, (stream_id, self._instance_name, room_id)) self.db_pool.simple_delete_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, @@ -458,8 +483,8 @@ class PersistEventsStore: # sql = """ INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, ?, ?, ?, ?, ( + (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, ?, ?, ?, ?, ( SELECT event_id FROM current_state_events WHERE room_id = ? AND type = ? AND state_key = ? ) @@ -469,6 +494,7 @@ class PersistEventsStore: ( ( stream_id, + self._instance_name, room_id, etype, state_key, @@ -743,7 +769,7 @@ class PersistEventsStore: logger.exception("") raise - metadata_json = encode_json(event.internal_metadata.get_dict()) + metadata_json = json_encoder.encode(event.internal_metadata.get_dict()) sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -759,6 +785,7 @@ class PersistEventsStore: "event_stream_ordering": stream_order, "event_id": event.event_id, "state_group": state_group_id, + "instance_name": self._instance_name, }, ) @@ -797,10 +824,10 @@ class PersistEventsStore: { "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": encode_json( + "internal_metadata": json_encoder.encode( event.internal_metadata.get_dict() ), - "json": encode_json(event_dict(event)), + "json": json_encoder.encode(event_dict(event)), "format_version": event.format_version, } for event, _ in events_and_contexts @@ -1022,9 +1049,7 @@ class PersistEventsStore: def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.prefill( - (cache_entry[0].event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) @@ -1241,6 +1266,10 @@ class PersistEventsStore: ) def _store_retention_policy_for_room_txn(self, txn, event): + if not event.is_state(): + logger.debug("Ignoring non-state m.room.retention event") + return + if hasattr(event, "content") and ( "min_lifetime" in event.content or "max_lifetime" in event.content ): diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 5e4af2eb51..97b6754846 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -92,6 +92,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT have_censored", ) + self.db_pool.updates.register_background_index_update( + "users_have_local_media", + index_name="users_have_local_media", + table="local_media_repository", + columns=["user_id", "created_ts"], + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f95679ebc4..4732685f6e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ 
b/synapse/storage/databases/main/events_worker.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging import threading @@ -32,9 +31,13 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream @@ -42,8 +45,9 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator -from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached +from synapse.types import Collection, JsonDict, get_domain_from_id +from synapse.util.caches.descriptors import cached +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -74,6 +78,13 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): + # Whether to use dedicated DB threads for event fetching. This is only used + # if there are multiple DB threads available. When used will lock the DB + # thread for periods of time (so unit tests want to disable this when they + # run DB transactions on the main thread). See EVENT_QUEUE_* for more + # options controlling this. + USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -130,11 +141,16 @@ class EventsWorkerStore(SQLBaseStore): db_conn, "events", "stream_ordering", step=-1 ) - self._get_event_cache = Cache( - "*getEvent*", + if hs.config.run_background_tasks: + # We periodically clean out old transaction ID mappings + self._clock.looping_call( + self._cleanup_old_transaction_ids, 5 * 60 * 1000, + ) + + self._get_event_cache = LruCache( + cache_name="*getEvent*", keylen=3, - max_entries=hs.config.caches.event_cache_size, - apply_cache_factor_from_config=False, + max_size=hs.config.caches.event_cache_size, ) self._event_fetch_lock = threading.Condition() @@ -510,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore): return event_map + async def get_stripped_room_state_from_event_context( + self, + context: EventContext, + state_types_to_include: List[EventTypes], + membership_user_id: Optional[str] = None, + ) -> List[JsonDict]: + """ + Retrieve the stripped state from a room, given an event context to retrieve state + from as well as the state types to include. Optionally, include the membership + events from a specific user. + + "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys + are included from each state event. + + Args: + context: The event context to retrieve state of the room from. 
+ state_types_to_include: The type of state events to include. + membership_user_id: An optional user ID to include the stripped membership state + events of. This is useful when generating the stripped state of a room for + invites. We want to send membership events of the inviter, so that the + invitee can display the inviter's profile information if the room lacks any. + + Returns: + A list of dictionaries, each representing a stripped state event from the room. + """ + current_state_ids = await context.get_current_state_ids() + + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + + # The state to include + state_to_include_ids = [ + e_id + for k, e_id in current_state_ids.items() + if k[0] in state_types_to_include + or (membership_user_id and k == (EventTypes.Member, membership_user_id)) + ] + + state_to_include = await self.get_events(state_to_include_ids) + + return [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "sender": e.sender, + } + for e in state_to_include.values() + ] + def _do_fetch(self, conn): """Takes a database connection and waits for requests for events from the _event_fetch_list queue. @@ -522,7 +589,11 @@ class EventsWorkerStore(SQLBaseStore): if not event_list: single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): self._event_fetch_ongoing -= 1 return else: @@ -712,6 +783,7 @@ class EventsWorkerStore(SQLBaseStore): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) + original_ev.internal_metadata.stream_ordering = row["stream_ordering"] event_map[event_id] = original_ev @@ -728,7 +800,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.prefill((event_id,), cache_entry) + self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry return result_map @@ -779,6 +851,8 @@ class EventsWorkerStore(SQLBaseStore): * event_id (str) + * stream_ordering (int): stream ordering for this event + * json (str): json-encoded event structure * internal_metadata (str): json-encoded internal metadata dict @@ -811,13 +885,15 @@ class EventsWorkerStore(SQLBaseStore): sql = """\ SELECT e.event_id, - e.internal_metadata, - e.json, - e.format_version, + e.stream_ordering, + ej.internal_metadata, + ej.json, + ej.format_version, r.room_version, rej.reason - FROM event_json as e - LEFT JOIN rooms r USING (room_id) + FROM events AS e + JOIN event_json AS ej USING (event_id) + LEFT JOIN rooms r ON r.room_id = e.room_id LEFT JOIN rejections as rej USING (event_id) WHERE """ @@ -831,11 +907,12 @@ class EventsWorkerStore(SQLBaseStore): event_id = row[0] event_dict[event_id] = { "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "room_version_id": row[4], - "rejected_reason": row[5], + "stream_ordering": row[1], + "internal_metadata": row[2], + "json": row[3], + "format_version": row[4], + "room_version_id": row[5], + "rejected_reason": row[6], "redactions": [], } @@ -1017,16 +1094,12 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - def get_current_events_token(self): """The current maximum 
token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( - self, last_id: int, current_id: int, limit: int + self, instance_name: str, last_id: int, current_id: int, limit: int ) -> List[Tuple]: """Returns new events, for the Events replication stream @@ -1044,16 +1117,19 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" + " AND instance_name = ?" " ORDER BY stream_ordering ASC" " LIMIT ?" ) - txn.execute(sql, (last_id, current_id, limit)) + txn.execute(sql, (last_id, current_id, instance_name, limit)) return txn.fetchall() return await self.db_pool.runInteraction( @@ -1061,7 +1137,7 @@ class EventsWorkerStore(SQLBaseStore): ) async def get_ex_outlier_stream_rows( - self, last_id: int, current_id: int + self, instance_name: str, last_id: int, current_id: int ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream @@ -1078,18 +1154,21 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" + " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" + " AND out.instance_name = ?" " ORDER BY event_stream_ordering ASC" ) - txn.execute(sql, (last_id, current_id)) + txn.execute(sql, (last_id, current_id, instance_name)) return txn.fetchall() return await self.db_pool.runInteraction( @@ -1102,6 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. + NOTE: The IDs given here are from replication, and so should be + *positive*. + Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. @@ -1132,10 +1214,11 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" + " AND instance_name = ?" " ORDER BY stream_ordering ASC" " LIMIT ?" 
) - txn.execute(sql, (-last_id, -current_id, limit)) + txn.execute(sql, (-last_id, -current_id, instance_name, limit)) new_event_updates = [(row[0], row[1:]) for row in txn] limited = False @@ -1149,15 +1232,16 @@ class EventsWorkerStore(SQLBaseStore): "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" + " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" + " AND out.instance_name = ?" " ORDER BY event_stream_ordering DESC" ) - txn.execute(sql, (-last_id, -upper_bound)) + txn.execute(sql, (-last_id, -upper_bound, instance_name)) new_event_updates.extend((row[0], row[1:]) for row in txn) if len(new_event_updates) >= limit: @@ -1171,7 +1255,7 @@ class EventsWorkerStore(SQLBaseStore): ) async def get_all_updated_current_state_deltas( - self, from_token: int, to_token: int, target_row_count: int + self, instance_name: str, from_token: int, to_token: int, target_row_count: int ) -> Tuple[List[Tuple], int, bool]: """Fetch updates from current_state_delta_stream @@ -1197,9 +1281,10 @@ class EventsWorkerStore(SQLBaseStore): SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE ? < stream_id AND stream_id <= ? + AND instance_name = ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, target_row_count)) + txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) return txn.fetchall() def get_deltas_for_stream_id_txn(txn, stream_id): @@ -1287,3 +1372,77 @@ class EventsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + + async def get_event_id_from_transaction_id( + self, room_id: str, user_id: str, token_id: int, txn_id: str + ) -> Optional[str]: + """Look up if we have already persisted an event for the transaction ID, + returning the event ID if so. + """ + return await self.db_pool.simple_select_one_onecol( + table="event_txn_id", + keyvalues={ + "room_id": room_id, + "user_id": user_id, + "token_id": token_id, + "txn_id": txn_id, + }, + retcol="event_id", + allow_none=True, + desc="get_event_id_from_transaction_id", + ) + + async def get_already_persisted_events( + self, events: Iterable[EventBase] + ) -> Dict[str, str]: + """Look up if we have already persisted an event for the transaction ID, + returning a mapping from event ID in the given list to the event ID of + an existing event. + + Also checks if there are duplicates in the given events, if there are + will map duplicates to the *first* event. + """ + + mapping = {} + txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str] + + for event in events: + token_id = getattr(event.internal_metadata, "token_id", None) + txn_id = getattr(event.internal_metadata, "txn_id", None) + + if token_id and txn_id: + # Check if this is a duplicate of an event in the given events. + existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) + if existing: + mapping[event.event_id] = existing + continue + + # Check if this is a duplicate of an event we've already + # persisted. 
+ existing = await self.get_event_id_from_transaction_id( + event.room_id, event.sender, token_id, txn_id + ) + if existing: + mapping[event.event_id] = existing + txn_id_to_event[(event.room_id, token_id, txn_id)] = existing + else: + txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id + + return mapping + + @wrap_as_background_process("_cleanup_old_transaction_ids") + async def _cleanup_old_transaction_ids(self): + """Cleans out transaction id mappings older than 24hrs. + """ + + def _cleanup_old_transaction_ids_txn(txn): + sql = """ + DELETE FROM event_txn_id + WHERE inserted_ts < ? + """ + one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + txn.execute(sql, (one_day_ago,)) + + return await self.db_pool.runInteraction( + "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, + ) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index ad43bb05ab..f8f4bb9b3f 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - await self.db_pool.runInteraction( - "store_server_verify_keys", - self.db_pool.simple_upsert_many_txn, + await self.db_pool.simple_upsert_many( table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, + desc="store_server_verify_keys", ) invalidate = self._get_server_verify_key.invalidate diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index cc538c5c10..4b2f224718 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -93,6 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self.server_name = hs.hostname async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media @@ -115,6 +116,109 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media", ) + async def get_local_media_by_user_paginate( + self, start: int, limit: int, user_id: str + ) -> Tuple[List[Dict[str, Any]], int]: + """Get a paginated list of metadata for a local piece of media + which an user_id has uploaded + + Args: + start: offset in the list + limit: maximum amount of media_ids to retrieve + user_id: fully-qualified user id + Returns: + A paginated list of all metadata of user's media, + plus the total count of all the user's media + """ + + def get_local_media_by_user_paginate_txn(txn): + + args = [user_id] + sql = """ + SELECT COUNT(*) as total_media + FROM local_media_repository + WHERE user_id = ? + """ + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + "media_id", + "media_type", + "media_length", + "upload_name", + "created_ts", + "last_access_ts", + "quarantined_by", + "safe_from_quarantine" + FROM local_media_repository + WHERE user_id = ? + ORDER BY created_ts DESC, media_id DESC + LIMIT ? OFFSET ? 
+ """ + + args += [limit, start] + txn.execute(sql, args) + media = self.db_pool.cursor_to_dict(txn) + return media, count + + return await self.db_pool.runInteraction( + "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn + ) + + async def get_local_media_before( + self, before_ts: int, size_gt: int, keep_profiles: bool, + ) -> Optional[List[str]]: + + # to find files that have never been accessed (last_access_ts IS NULL) + # compare with `created_ts` + sql = """ + SELECT media_id + FROM local_media_repository AS lmr + WHERE + ( last_access_ts < ? + OR ( created_ts < ? AND last_access_ts IS NULL ) ) + AND media_length > ? + """ + + if keep_profiles: + sql_keep = """ + AND ( + NOT EXISTS + (SELECT 1 + FROM profiles + WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM groups + WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM room_memberships + WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM user_directory + WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM room_stats_state + WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id) + ) + """.format( + media_prefix="mxc://%s/" % (self.server_name,), + ) + sql += sql_keep + + def _get_local_media_before_txn(txn): + txn.execute(sql, (before_ts, before_ts, size_gt)) + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_local_media_before", _get_local_media_before_txn + ) + async def store_local_media( self, media_id, @@ -348,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) + async def get_remote_media_thumbnail( + self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str, + ) -> Optional[Dict[str, Any]]: + """Fetch the thumbnail info of given width, height and type. + """ + + return await self.db_pool.simple_select_one( + table="remote_media_cache_thumbnails", + keyvalues={ + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": t_width, + "thumbnail_height": t_height, + "thumbnail_type": t_type, + }, + retcols=( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + allow_none=True, + desc="get_remote_media_thumbnail", + ) + async def store_remote_media_thumbnail( self, origin, diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 92099f95ce..ab18cc4d79 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,15 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import calendar +import logging +import time +from typing import Dict from synapse.metrics import GaugeBucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +logger = logging.getLogger(__name__) + # Collect metrics on the number of forward extremities that exist. 
_extremities_collecter = GaugeBucketCollector( "synapse_forward_extremities", @@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) + if hs.config.run_background_tasks: + self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + @wrap_as_background_process("read_forward_extremities") async def _read_forward_extremities(self): def fetch(txn): txn.execute( @@ -137,3 +141,196 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) + + async def count_daily_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return await self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) + + async def count_monthly_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return await self.db_pool.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + async def count_r30_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns: + A mapping of counts globally as well as broken out by platform. + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? 
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + @wrap_as_background_process("generate_user_daily_visits") + async def generate_user_daily_visits(self) -> None: + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self._clock.time_msec() + + # A note on user_agent. Technically a given device can have multiple + # user agents, so we need to decide which one to pick. We could have + # handled this in number of ways, but given that we don't care + # _that_ much we have gone for MAX(). For more details of the other + # options considered see + # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111 + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent) + SELECT u.user_id, u.device_id, ?, MAX(u.user_agent) + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. 
The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + await self.db_pool.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e93aad33cd..d788dc0fc6 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -15,6 +15,7 @@ import logging from typing import Dict, List +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached @@ -32,6 +33,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._max_mau_value = hs.config.max_mau_value + @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -124,60 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): desc="user_last_seen_monthly_active", ) - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_stats_only = hs.config.mau_stats_only - self._max_mau_value = hs.config.max_mau_value - - # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting #6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - + @wrap_as_background_process("reap_monthly_active_users") async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. 
@@ -257,6 +208,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._mau_stats_only = hs.config.mau_stats_only + + # Do not add more reserved users than the total allowable number + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index d2e0685e9e..0e25ca3d7a 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
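For reference on _initialise_reserved_users above: each registered reserved threepid simply has its monthly_active_users row refreshed with an upsert. On an engine with native upserts that is roughly the statement below; this is illustrative only, the placeholder style and conflict clause vary by engine, and Synapse's simple_upsert_txn abstracts those differences.

    # Illustrative sketch only: refresh one reserved user's MAU timestamp.
    REFRESH_RESERVED_MAU = """
        INSERT INTO monthly_active_users (user_id, timestamp)
        VALUES (?, ?)
        ON CONFLICT (user_id) DO UPDATE SET timestamp = EXCLUDED.timestamp
    """

    def refresh_reserved_user(txn, user_id: str, now_ms: int) -> None:
        # `txn` is assumed to be a DB-API style cursor.
        txn.execute(REFRESH_RESERVED_MAU, (user_id, now_ms))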
-from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - async def get_profile_displayname(self, user_localpart: str) -> str: + async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, @@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) - async def get_profile_avatar_url(self, user_localpart: str) -> str: + async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, @@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: str + self, user_localpart: str, new_displayname: Optional[str] ) -> None: await self.db_pool.simple_update_one( table="profiles", @@ -91,27 +91,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="set_profile_avatar_url", ) - -class ProfileStore(ProfileWorkerStore): - async def add_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str - ) -> None: - """Ensure we are caching the remote user's profiles. - - This should only be called when `is_subscribed_remote_profile_for_user` - would return true for the user. - """ - await self.db_pool.simple_upsert( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="add_remote_profile_cache", - ) - async def update_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> int: @@ -138,9 +117,34 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) + async def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = await self.db_pool.simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True + + res = await self.db_pool.simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + return True + async def get_remote_profile_cache_entries_that_expire( self, last_checked: int - ) -> Dict[str, str]: + ) -> List[Dict[str, str]]: """Get all users who haven't been checked since `last_checked` """ @@ -160,27 +164,23 @@ class ProfileStore(ProfileWorkerStore): _get_remote_profile_cache_entries_that_expire_txn, ) - async def is_subscribed_remote_profile_for_user(self, user_id): - """Check whether we are interested in a remote user's profile. 
- """ - res = await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - if res: - return True +class ProfileStore(ProfileWorkerStore): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: + """Ensure we are caching the remote user's profiles. - res = await self.db_pool.simple_select_one_onecol( - table="group_invites", + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + await self.db_pool.simple_upsert( + table="remote_profile_cache", keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", ) - - if res: - return True diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index ecfc6717b3..5d668aadb2 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): for table in ( "event_auth", "event_edges", + "event_json", "event_push_actions_staging", "event_reference_hashes", "event_relations", @@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "destination_rooms", "event_backward_extremities", "event_forward_extremities", - "event_json", "event_push_actions", "event_search", "events", diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df8609b97b..7997242d90 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore): lock=False, ) - user_has_pusher = self.get_if_user_has_pusher.cache.get( + user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( (user_id,), None, update_metrics=False ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c79ddff680..1e7949a323 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -23,8 +23,8 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -274,6 +274,65 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): } return results + @cached(num_args=2,) + async def get_linearized_receipts_for_all_rooms( + self, to_key: int, from_key: Optional[int] = None + ) -> Dict[str, JsonDict]: + """Get receipts for all rooms between two stream_ids, up + to a limit of the latest 100 read receipts. + + Args: + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + A dictionary of roomids to a list of receipts. 
+ """ + + def f(txn): + if from_key: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? + ORDER BY stream_id DESC + LIMIT 100 + """ + txn.execute(sql, [from_key, to_key]) + else: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? + ORDER BY stream_id DESC + LIMIT 100 + """ + + txn.execute(sql, [to_key]) + + return self.db_pool.cursor_to_dict(txn) + + txn_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_all_rooms", f + ) + + results = {} + for row in txn_results: + # We want a single event per room, since we want to batch the + # receipts by room, event and type. + room_event = results.setdefault( + row["room_id"], + {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + ) + + # The content is of the form: + # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } + event_entry = room_event["content"].setdefault(row["event_id"], {}) + receipt_type = event_entry.setdefault(row["receipt_type"], {}) + + receipt_type[row["user_id"]] = db_to_json(row["data"]) + + return results + async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: @@ -358,18 +417,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): if receipt_type != "m.read": return - # Returns either an ObservableDeferred or the raw result - res = self.get_users_with_read_receipts_in_room.cache.get( + res = self.get_users_with_read_receipts_in_room.cache.get_immediate( room_id, None, update_metrics=False ) - # first handle the ObservableDeferred case - if isinstance(res, ObservableDeferred): - if res.has_called(): - res = res.get_result() - else: - res = None - if res and user_id in res: # We'd only be adding to the set, so no point invalidating if the # user is already there diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index a83df7759d..fedb8a6c26 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,32 +14,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
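The receipts change above replaces a raw cache.get(...), which could hand back a still-pending ObservableDeferred, with get_immediate(...), which only ever returns an already-computed value or the supplied default. A tiny standalone illustration of that contract follows; it uses a simplified cache class, not the real DeferredCache.

    from typing import Any, Dict

    class SketchAsyncCache:
        """Minimal stand-in for a cache that may hold pending lookups."""

        def __init__(self) -> None:
            self._completed: Dict[Any, Any] = {}
            self._pending: set = set()

        def mark_pending(self, key: Any) -> None:
            self._pending.add(key)

        def set(self, key: Any, value: Any) -> None:
            self._pending.discard(key)
            self._completed[key] = value

        def get_immediate(self, key: Any, default: Any = None) -> Any:
            # Never exposes a pending entry: either the finished value or the
            # caller-supplied default, with no deferred to unwrap.
            return self._completed.get(key, default)

    cache = SketchAsyncCache()
    cache.mark_pending("!room:example.org")
    assert cache.get_immediate("!room:example.org", None) is None
    cache.set("!room:example.org", {"@user:example.org"})
    assert cache.get_immediate("!room:example.org", None) == {"@user:example.org"}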
- import logging import re -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import SQLBaseStore +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.types import Connection, Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) -class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): +@attr.s(frozen=True, slots=True) +class TokenLookupResult: + """Result of looking up an access token. + + Attributes: + user_id: The user that this token authenticates as + is_guest + shadow_banned + token_id: The ID of the access token looked up + device_id: The device associated with the token, if any. + valid_until_ms: The timestamp the token expires, if any. + token_owner: The "owner" of the token. This is either the same as the + user, or a server admin who is logged in as the user. + """ + + user_id = attr.ib(type=str) + is_guest = attr.ib(type=bool, default=False) + shadow_banned = attr.ib(type=bool, default=False) + token_id = attr.ib(type=Optional[int], default=None) + device_id = attr.ib(type=Optional[str], default=None) + valid_until_ms = attr.ib(type=Optional[int], default=None) + token_owner = attr.ib(type=str) + + # Make the token owner default to the user ID, which is the common case. 
+ @token_owner.default + def _default_token_owner(self): + return self.user_id + + +class RegistrationWorkerStore(CacheInvalidationWorkerStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config - self.clock = hs.get_clock() # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is @@ -48,6 +82,18 @@ class RegistrationWorkerStore(SQLBaseStore): database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) + self._account_validity = hs.config.account_validity + if hs.config.run_background_tasks and self._account_validity.enabled: + self._clock.call_later( + 0.0, self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + if hs.config.run_background_tasks: + self._clock.looping_call( + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS + ) + @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( @@ -81,21 +127,19 @@ class RegistrationWorkerStore(SQLBaseStore): if not info: return False - now = self.clock.time_msec() + now = self._clock.time_msec() trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial @cached() - async def get_user_by_access_token(self, token: str) -> Optional[dict]: + async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]: """Get a user from the given access token. Args: token: The access token of a user. Returns: - None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise a `TokenLookupResult` """ return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token @@ -225,13 +269,13 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_renewal_token_for_user", ) - async def get_users_expiring_soon(self) -> List[Dict[str, int]]: + async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - A list of dictionaries mapping user ID to expiration time (in milliseconds). + A list of dictionaries, each with a user ID and expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -246,7 +290,7 @@ class RegistrationWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), + self._clock.time_msec(), self.config.account_validity.renew_at, ) @@ -316,19 +360,24 @@ class RegistrationWorkerStore(SQLBaseStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," - " access_tokens.device_id, access_tokens.valid_until_ms" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" 
- ) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: + sql = """ + SELECT users.name as user_id, + users.is_guest, + users.shadow_banned, + access_tokens.id as token_id, + access_tokens.device_id, + access_tokens.valid_until_ms, + access_tokens.user_id as token_owner + FROM users + INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) + WHERE token = ? + """ txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) if rows: - return rows[0] + return TokenLookupResult(**rows[0]) return None @@ -778,12 +827,111 @@ class RegistrationWorkerStore(SQLBaseStore): "delete_threepid_session", delete_threepid_session_txn ) + @wrap_as_background_process("cull_expired_threepid_validation_tokens") + async def cull_expired_threepid_validation_tokens(self) -> None: + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self._clock.time_msec(), + ) + + @wrap_as_background_process("account_validity_set_expiration_dates") + async def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db_pool.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + await self.db_pool.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db_pool.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) + + async def get_user_pending_deactivation(self) -> Optional[str]: + """ + Gets one user from the table of users waiting to be parted from all the rooms + they're in. 
+ """ + return await self.db_pool.simple_select_one_onecol( + "users_pending_deactivation", + keyvalues={}, + retcol="user_id", + allow_none=True, + desc="get_users_pending_deactivation", + ) + + async def del_user_pending_deactivation(self, user_id: str) -> None: + """ + Removes the given user to the table of users who need to be parted from all the + rooms they're in, effectively marking that user as fully deactivated. + """ + # XXX: This should be simple_delete_one but we failed to put a unique index on + # the table, so somehow duplicate entries have ended up in it. + await self.db_pool.simple_delete( + "users_pending_deactivation", + keyvalues={"user_id": user_id}, + desc="del_user_pending_deactivation", + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self.clock = hs.get_clock() + self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -906,32 +1054,55 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return 1 + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: + """Set the `deactivated` property for the provided user to the provided value. -class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) + Args: + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. + """ - self._account_validity = hs.config.account_validity - self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + await self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) + def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) +class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") async def 
add_access_token_to_user( self, @@ -939,7 +1110,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): token: str, device_id: Optional[str], valid_until_ms: Optional[int], - ) -> None: + puppets_user_id: Optional[str] = None, + ) -> int: """Adds an access token for the given user. Args: @@ -949,6 +1121,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. + Returns: + The token ID """ next_id = self._access_tokens_id_gen.get_next() @@ -960,10 +1134,43 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "token": token, "device_id": device_id, "valid_until_ms": valid_until_ms, + "puppets_user_id": puppets_user_id, }, desc="add_access_token_to_user", ) + return next_id + + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, "access_tokens", {"token": token}, "device_id" + ) + + self.db_pool.simple_update_txn( + txn, "access_tokens", {"token": token}, {"device_id": device_id} + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) + + return old_device_id + + async def set_device_for_access_token(self, token: str, device_id: str) -> str: + """Sets the device ID associated with an access token. + + Args: + token: The access token to modify. + device_id: The new device ID. + Returns: + The old device ID associated with the access token. + """ + + return await self.db_pool.runInteraction( + "set_device_for_access_token", + self._set_device_for_access_token_txn, + token, + device_id, + ) + async def register_user( self, user_id: str, @@ -1014,19 +1221,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def _register_user( self, txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - shadow_banned, + user_id: str, + password_hash: Optional[str], + was_guest: bool, + make_guest: bool, + appservice_id: Optional[str], + create_profile_with_displayname: Optional[str], + admin: bool, + user_type: Optional[str], + shadow_banned: bool, ): user_id_obj = UserID.from_string(user_id) - now = int(self.clock.time()) + now = int(self._clock.time()) try: if was_guest: @@ -1121,7 +1328,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="record_user_external_id", ) - async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: + async def user_set_password_hash( + self, user_id: str, password_hash: Optional[str] + ) -> None: """ NB. 
This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be @@ -1248,18 +1457,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) - @cached() - async def is_guest(self, user_id: str) -> bool: - res = await self.db_pool.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're @@ -1271,32 +1468,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_user_pending_deactivation", ) - async def del_user_pending_deactivation(self, user_id: str) -> None: - """ - Removes the given user to the table of users who need to be parted from all the - rooms they're in, effectively marking that user as fully deactivated. - """ - # XXX: This should be simple_delete_one but we failed to put a unique index on - # the table, so somehow duplicate entries have ended up in it. - await self.db_pool.simple_delete( - "users_pending_deactivation", - keyvalues={"user_id": user_id}, - desc="del_user_pending_deactivation", - ) - - async def get_user_pending_deactivation(self) -> Optional[str]: - """ - Gets one user from the table of users waiting to be parted from all the rooms - they're in. - """ - return await self.db_pool.simple_select_one_onecol( - "users_pending_deactivation", - keyvalues={}, - retcol="user_id", - allow_none=True, - desc="get_users_pending_deactivation", - ) - async def validate_threepid_session( self, session_id: str, client_secret: str, token: str, current_ts: int ) -> Optional[str]: @@ -1379,7 +1550,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, + updatevalues={"validated_at": self._clock.time_msec()}, ) return next_link @@ -1447,106 +1618,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def cull_expired_threepid_validation_tokens(self) -> None: - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn(txn, ts): - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - txn.execute(sql, (ts,)) - - await self.db_pool.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), - ) - - async def set_user_deactivated_status( - self, user_id: str, deactivated: bool - ) -> None: - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - deactivated: The value to set for `deactivated`. 
- """ - - await self.db_pool.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - txn.call_after(self.is_guest.invalidate, (user_id,)) - - async def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.db_pool.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - await self.db_pool.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self.db_pool.simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3c7630857f..6b89db15c9 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore): "count_public_rooms", _count_public_rooms_txn ) + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return await self.db_pool.runInteraction("get_rooms", f) + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], @@ -857,6 +869,89 @@ class RoomWorkerStore(SQLBaseStore): "get_all_new_public_rooms", get_all_new_public_rooms ) + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: + """Retrieves all of the rooms within the given retention range. + + Optionally includes the rooms which don't have a retention policy. 
+ + Args: + min_ms: Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, doesn't set a lower limit. + max_ms: Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, doesn't set an upper limit. + include_null: Whether to include rooms which retention policy is NULL + in the returned set. + + Returns: + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). + """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.db_pool.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.db_pool.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. + for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + return await self.db_pool.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1145,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + async def maybe_store_room_on_outlier_membership( + self, room_id: str, room_version: RoomVersion + ): """ - When we receive an invite over federation, store the version of the room if we - don't already know the room version. + When we receive an invite or any other event over federation that may relate to a room + we are not in, store the version of the room if we don't already know the room version. """ await self.db_pool.simple_upsert( - desc="maybe_store_room_on_invite", + desc="maybe_store_room_on_outlier_membership", table="rooms", keyvalues={"room_id": room_id}, values={}, @@ -1292,18 +1389,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - async def get_room_count(self) -> int: - """Retrieve the total number of rooms. 
- """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return await self.db_pool.runInteraction("get_rooms", f) - async def add_event_report( self, room_id: str, @@ -1328,6 +1413,65 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: + """Retrieve an event report + + Args: + report_id: ID of reported event in database + Returns: + event_report: json list of information from event report + """ + + def _get_event_report_txn(txn, report_id): + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + WHERE er.id = ? + """ + + txn.execute(sql, [report_id]) + row = txn.fetchone() + + if not row: + return None + + event_report = { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": db_to_json(row[5]).get("score"), + "reason": db_to_json(row[5]).get("reason"), + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + "event_json": db_to_json(row[9]), + } + + return event_report + + return await self.db_pool.runInteraction( + "get_event_report", _get_event_report_txn, report_id + ) + async def get_event_reports_paginate( self, start: int, @@ -1385,18 +1529,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): er.room_id, er.event_id, er.user_id, - er.reason, er.content, events.sender, - room_aliases.room_alias, - event_json.json AS event_json + room_stats_state.canonical_alias, + room_stats_state.name FROM event_reports AS er - LEFT JOIN room_aliases - ON room_aliases.room_id = er.room_id - JOIN events + LEFT JOIN events ON events.event_id = er.event_id - JOIN event_json - ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id {where_clause} ORDER BY er.received_ts {order} LIMIT ? @@ -1407,15 +1548,29 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): args += [limit, start] txn.execute(sql, args) - event_reports = self.db_pool.cursor_to_dict(txn) - - if count > 0: - for row in event_reports: - try: - row["content"] = db_to_json(row["content"]) - row["event_json"] = db_to_json(row["event_json"]) - except Exception: - continue + + event_reports = [] + for row in txn: + try: + s = db_to_json(row[5]).get("score") + r = db_to_json(row[5]).get("reason") + except Exception: + logger.error("Unable to parse json from event_reports: %s", row[0]) + continue + event_reports.append( + { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": s, + "reason": r, + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + } + ) return event_reports, count @@ -1446,88 +1601,3 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.is_room_blocked, (room_id,), ) - - async def get_rooms_for_retention_period_in_range( - self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False - ) -> Dict[str, dict]: - """Retrieves all of the rooms within the given retention range. 
- - Optionally includes the rooms which don't have a retention policy. - - Args: - min_ms: Duration in milliseconds that define the lower limit of - the range to handle (exclusive). If None, doesn't set a lower limit. - max_ms: Duration in milliseconds that define the upper limit of - the range to handle (inclusive). If None, doesn't set an upper limit. - include_null: Whether to include rooms which retention policy is NULL - in the returned set. - - Returns: - The rooms within this range, along with their retention - policy. The key is "room_id", and maps to a dict describing the retention - policy associated with this room ID. The keys for this nested dict are - "min_lifetime" (int|None), and "max_lifetime" (int|None). - """ - - def get_rooms_for_retention_period_in_range_txn(txn): - range_conditions = [] - args = [] - - if min_ms is not None: - range_conditions.append("max_lifetime > ?") - args.append(min_ms) - - if max_ms is not None: - range_conditions.append("max_lifetime <= ?") - args.append(max_ms) - - # Do a first query which will retrieve the rooms that have a retention policy - # in their current state. - sql = """ - SELECT room_id, min_lifetime, max_lifetime FROM room_retention - INNER JOIN current_state_events USING (event_id, room_id) - """ - - if len(range_conditions): - sql += " WHERE (" + " AND ".join(range_conditions) + ")" - - if include_null: - sql += " OR max_lifetime IS NULL" - - txn.execute(sql, args) - - rows = self.db_pool.cursor_to_dict(txn) - rooms_dict = {} - - for row in rows: - rooms_dict[row["room_id"]] = { - "min_lifetime": row["min_lifetime"], - "max_lifetime": row["max_lifetime"], - } - - if include_null: - # If required, do a second query that retrieves all of the rooms we know - # of so we can handle rooms with no retention policy. - sql = "SELECT DISTINCT room_id FROM current_state_events" - - txn.execute(sql) - - rows = self.db_pool.cursor_to_dict(txn) - - # If a room isn't already in the dict (i.e. it doesn't have a retention - # policy in its state), add it with a null policy. - for row in rows: - if row["room_id"] not in rooms_dict: - rooms_dict[row["room_id"]] = { - "min_lifetime": None, - "max_lifetime": None, - } - - return rooms_dict - - rooms = await self.db_pool.runInteraction( - "get_rooms_for_retention_period_in_range", - get_rooms_for_retention_period_in_range_txn, - ) - - return rooms diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 86ffe2479e..dcdaf09682 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -14,19 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
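A brief usage sketch (not part of the patch) of the dict returned by get_rooms_for_retention_period_in_range above; the store handle and the pruning routine are hypothetical, only the method name, its arguments and the returned shape come from the patch:

    async def prune_rooms_in_band(store, min_ms, max_ms, default_max_lifetime):
        # Rooms without a retention state event come back with None lifetimes
        # when include_null=True.
        rooms = await store.get_rooms_for_retention_period_in_range(
            min_ms, max_ms, include_null=True
        )
        for room_id, policy in rooms.items():
            max_lifetime = (
                policy["max_lifetime"]
                if policy["max_lifetime"] is not None
                else default_max_lifetime
            )
            print(room_id, policy["min_lifetime"], max_lifetime)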
import logging -from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - LoggingTransaction, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, ) +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine @@ -60,27 +58,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): # background update still running? self._current_state_events_membership_up_to_date = False - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, + txn = db_conn.cursor( + txn_name="_check_safe_current_state_events_membership_updated" ) self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() - if self.hs.config.metrics_flags.known_servers: + if ( + self.hs.config.run_background_tasks + and self.hs.config.metrics_flags.known_servers + ): self._known_servers_count = 1 self.hs.get_clock().looping_call( - run_as_background_process, - 60 * 1000, - "_count_known_servers", - self._count_known_servers, + self._count_known_servers, 60 * 1000, ) self.hs.get_clock().call_later( - 1000, - run_as_background_process, - "_count_known_servers", - self._count_known_servers, + 1000, self._count_known_servers, ) LaterGauge( "synapse_federation_known_servers", @@ -89,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): lambda: self._known_servers_count, ) + @wrap_as_background_process("_count_known_servers") async def _count_known_servers(self): """ Count the servers that this server knows about. @@ -356,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results + async def get_local_current_membership_for_user_in_room( + self, user_id: str, room_id: str + ) -> Tuple[Optional[str], Optional[str]]: + """Retrieve the current local membership state and event ID for a user in a room. + + Args: + user_id: The ID of the user. + room_id: The ID of the room. + + Returns: + A tuple of (membership_type, event_id). Both will be None if a + room_id/user_id pair is not found. + """ + # Paranoia check. 
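        # (Note, not part of the patch: local_current_membership only holds rows for
        # local users, so a remote user_id would otherwise just yield (None, None);
        # raising here makes the misuse explicit rather than silent.)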
+ if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_local_current_membership_for_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + results_dict = await self.db_pool.simple_select_one( + "local_current_membership", + {"room_id": room_id, "user_id": user_id}, + ("membership", "event_id"), + allow_none=True, + desc="get_local_current_membership_for_user_in_room", + ) + if not results_dict: + return None, None + + return results_dict.get("membership"), results_dict.get("event_id") + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( self, user_id: str @@ -535,7 +561,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( + prev_res = self._get_joined_users_from_context.cache.get_immediate( (room_id, context.prev_group), None ) if prev_res and isinstance(prev_res, dict): diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py index 3edfcfd783..45b846e6a7 100644 --- a/synapse/storage/databases/main/schema/delta/20/pushers.py +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py @@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs): row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" 
for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py index ee675e71ff..21f57825d4 100644 --- a/synapse/storage/databases/main/schema/delta/25/fts.py +++ b/synapse/storage/databases/main/schema/delta/25/fts.py @@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py index b7972cfa8e..1c6058063f 100644 --- a/synapse/storage/databases/main/schema/delta/27/ts.py +++ b/synapse/storage/databases/main/schema/delta/27/ts.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py index b42c02710a..7f08fabe9f 100644 --- a/synapse/storage/databases/main/schema/delta/30/as_users.py +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py @@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),), [as_id] + chunk, ) diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py index 9bb504aad5..5be81c806a 100644 --- a/synapse/storage/databases/main/schema/delta/31/pushers.py +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py @@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs): row = list(row) row[12] = token_to_stream_ordering(row[12]) cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" 
for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py index 63b757ade6..b84c844e3a 100644 --- a/synapse/storage/databases/main/schema/delta/31/search_update.py +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py @@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search_order", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py index a3e81eeac7..e928c66a8f 100644 --- a/synapse/storage/databases/main/schema/delta/33/event_fields.py +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_fields_sender_url", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py index a26057dfb6..ad875c733a 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py @@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" - ), - (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py index 1de8b54961..bb7296852a 100644 --- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py @@ -1,6 +1,8 @@ import logging +from io import StringIO from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream logger = logging.getLogger(__name__) @@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs): select_clause, ) - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) + execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py index 63b5acdcf7..44917f0a2e 100644 --- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py @@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INNER JOIN room_memberships AS r USING (event_id) WHERE type = 'm.room.member' AND state_key LIKE ? 
""" - sql = database_engine.convert_param_style(sql) cur.execute(sql, ("%:" + config.server_name,)) cur.execute( diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres index b64926e9c9..3275ae2b20 100644 --- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres @@ -20,14 +20,14 @@ */ -- add new index that includes method to local media -INSERT INTO background_updates (update_name, progress_json) VALUES - ('local_media_repository_thumbnails_method_idx', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5807, 'local_media_repository_thumbnails_method_idx', '{}'); -- add new index that includes method to remote media -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); -- drop old index -INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES - ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql new file mode 100644 index 0000000000..7851a0a825 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql @@ -0,0 +1,20 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL -- JSON-encoded client-defined data +); diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql new file mode 100644 index 0000000000..4ed981dbf8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql index cade5dcca8..fd733adf13 100644 --- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql +++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql @@ -28,5 +28,5 @@ -- functionality as the old one. This effectively restarts the background job -- from the beginning, without running it twice in a row, supporting both -- upgrade usecases. -INSERT INTO background_updates (update_name, progress_json) VALUES - ('populate_stats_process_rooms_2', '{}'); +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5812, 'populate_stats_process_rooms_2', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres new file mode 100644 index 0000000000..841186b826 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A unique and immutable mapping between instance name and an integer ID. This +-- lets us refer to instances via a small ID in e.g. stream tokens, without +-- having to encode the full name. +CREATE TABLE IF NOT EXISTS instance_map ( + instance_id SERIAL PRIMARY KEY, + instance_name TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name); diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql new file mode 100644 index 0000000000..b2454121a8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql @@ -0,0 +1,40 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A map of recent events persisted with transaction IDs. Used to deduplicate +-- send event requests with the same transaction ID. +-- +-- Note: transaction IDs are scoped to the room ID/user ID/access token that was +-- used to make the request. +-- +-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the +-- events or access token we don't want to try and de-duplicate the event. +CREATE TABLE IF NOT EXISTS event_txn_id ( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + token_id BIGINT NOT NULL, + txn_id TEXT NOT NULL, + inserted_ts BIGINT NOT NULL, + FOREIGN KEY (event_id) + REFERENCES events (event_id) ON DELETE CASCADE, + FOREIGN KEY (token_id) + REFERENCES access_tokens (id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id); +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id); +CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts); diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql new file mode 100644 index 0000000000..ad1f481428 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT; +ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql new file mode 100644 index 0000000000..b0b5dcddce --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + -- Add new column to user_daily_visits to track user agent +ALTER TABLE user_daily_visits + ADD COLUMN user_agent TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql new file mode 100644 index 0000000000..7b84a207fd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT; +ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT; \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql new file mode 100644 index 0000000000..01ea6eddcf --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql @@ -0,0 +1 @@ +DROP TABLE device_max_stream_id; diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql new file mode 100644 index 0000000000..00a9431a97 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Whether the access token is an admin token for controlling another user. 
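-- (Note, not part of the patch: registration.py's _query_for_auth now joins users on
-- COALESCE(puppets_user_id, access_tokens.user_id), so a token with this column set
-- authenticates as the puppeted user while token_owner still reports the real owner.)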
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql new file mode 100644 index 0000000000..e1a35be831 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql @@ -0,0 +1,2 @@ +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5822, 'users_have_local_media', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql new file mode 100644 index 0000000000..75c3915a94 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5823, 'e2e_cross_signing_keys_idx', '{}'); diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql new file mode 100644 index 0000000000..8a39d54aed --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql @@ -0,0 +1,19 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this index is essentially redundant. The only time it was ever used was when purging +-- rooms - and Synapse 1.24 will change that. 
+ +DROP INDEX IF EXISTS event_json_room_id; diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 5beb302be3..0cdb3ec1f7 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,15 +16,18 @@ import logging from collections import Counter +from enum import Enum from itertools import chain from typing import Any, Dict, List, Optional, Tuple from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import StoreError from synapse.storage.database import DatabasePool from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -59,6 +62,23 @@ TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} +class UserSortOrder(Enum): + """ + Enum to define the sorting method used when returning users + with get_users_media_usage_paginate + + MEDIA_LENGTH = ordered by size of uploaded media. Smallest to largest. + MEDIA_COUNT = ordered by number of uploaded media. Smallest to largest. + USER_ID = ordered alphabetically by `user_id`. + DISPLAYNAME = ordered alphabetically by `displayname` + """ + + MEDIA_LENGTH = "media_length" + MEDIA_COUNT = "media_count" + USER_ID = "user_id" + DISPLAYNAME = "displayname" + + class StatsStore(StateDeltasStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -882,3 +902,110 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=pos, absolute_field_overrides={"joined_rooms": joined_rooms}, ) + + async def get_users_media_usage_paginate( + self, + start: int, + limit: int, + from_ts: Optional[int] = None, + until_ts: Optional[int] = None, + order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value, + direction: Optional[str] = "f", + search_term: Optional[str] = None, + ) -> Tuple[List[JsonDict], Dict[str, int]]: + """Function to retrieve a paginated list of users and their uploaded local media + (size and number). This will return a json list of users and the + total number of users matching the filter criteria. + + Args: + start: offset to begin the query from + limit: number of rows to retrieve + from_ts: request only media that are created later than this timestamp (ms) + until_ts: request only media that are created earlier than this timestamp (ms) + order_by: the sort order of the returned list + direction: sort ascending or descending + search_term: a string to filter user names by + Returns: + A list of user dicts and an integer representing the total number of + users that exist given this query + """ + + def get_users_media_usage_paginate_txn(txn): + filters = [] + args = [self.hs.config.server_name] + + if search_term: + filters.append("(lmr.user_id LIKE ? 
OR displayname LIKE ?)") + args.extend(["@%" + search_term + "%:%", "%" + search_term + "%"]) + + if from_ts: + filters.append("created_ts >= ?") + args.extend([from_ts]) + if until_ts: + filters.append("created_ts <= ?") + args.extend([until_ts]) + + # Set ordering + if UserSortOrder(order_by) == UserSortOrder.MEDIA_LENGTH: + order_by_column = "media_length" + elif UserSortOrder(order_by) == UserSortOrder.MEDIA_COUNT: + order_by_column = "media_count" + elif UserSortOrder(order_by) == UserSortOrder.USER_ID: + order_by_column = "lmr.user_id" + elif UserSortOrder(order_by) == UserSortOrder.DISPLAYNAME: + order_by_column = "displayname" + else: + raise StoreError( + 500, "Incorrect value for order_by provided: %s" % order_by + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql_base = """ + FROM local_media_repository as lmr + LEFT JOIN profiles AS p ON lmr.user_id = '@' || p.user_id || ':' || ? + {} + GROUP BY lmr.user_id, displayname + """.format( + where_clause + ) + + # SQLite does not support SELECT COUNT(*) OVER() + sql = """ + SELECT COUNT(*) FROM ( + SELECT lmr.user_id + {sql_base} + ) AS count_user_ids + """.format( + sql_base=sql_base, + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + lmr.user_id, + displayname, + COUNT(lmr.user_id) as media_count, + SUM(media_length) as media_length + {sql_base} + ORDER BY {order_by_column} {order} + LIMIT ? OFFSET ? + """.format( + sql_base=sql_base, order_by_column=order_by_column, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + users = self.db_pool.cursor_to_dict(txn) + + return users, count + + return await self.db_pool.runInteraction( + "get_users_media_usage_paginate_txn", get_users_media_usage_paginate_txn + ) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 37249f1e3f..e3b9ff5ca6 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -53,7 +53,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import Collection, PersistedEventPosition, RoomStreamToken +from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -208,6 +210,55 @@ def _make_generic_sql_bound( ) +def _filter_results( + lower_token: Optional[RoomStreamToken], + upper_token: Optional[RoomStreamToken], + instance_name: str, + topological_ordering: int, + stream_ordering: int, +) -> bool: + """Returns True if the event persisted by the given instance at the given + topological/stream_ordering falls between the two tokens (taking a None + token to mean unbounded). + + Used to filter results from fetching events in the DB against the given + tokens. This is necessary to handle the case where the tokens include + position maps, which we handle by fetching more than necessary from the DB + and then filtering (rather than attempting to construct a complicated SQL + query). + """ + + event_historical_tuple = ( + topological_ordering, + stream_ordering, + ) + + if lower_token: + if lower_token.topological is not None: + # If these are historical tokens we compare the `(topological, stream)` + # tuples. 
+ if event_historical_tuple <= lower_token.as_historical_tuple(): + return False + + else: + # If these are live tokens we compare the stream ordering against the + # writers stream position. + if stream_ordering <= lower_token.get_stream_pos_for_instance( + instance_name + ): + return False + + if upper_token: + if upper_token.topological is not None: + if upper_token.as_historical_tuple() < event_historical_tuple: + return False + else: + if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering: + return False + + return True + + def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create @@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): raise NotImplementedError() def get_room_max_token(self) -> RoomStreamToken: - return RoomStreamToken(None, self.get_room_max_stream_ordering()) + """Get a `RoomStreamToken` that marks the current maximum persisted + position of the events stream. Useful to get a token that represents + "now". + + The token returned is a "live" token that may have an instance_map + component. + """ + + min_pos = self._stream_id_gen.get_current_token() + + positions = {} + if isinstance(self._stream_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._stream_id_gen.get_positions().items() + if p > min_pos + } + + return RoomStreamToken(None, min_pos, positions) async def get_room_events_stream_for_rooms( self, @@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if from_key == to_key: return [], from_key - from_id = from_key.stream - to_id = to_key.stream - - has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) + has_changed = self._events_stream_cache.has_entity_changed( + room_id, from_key.stream + ) if not has_changed: return [], from_key def f(txn): - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT event_id, instance_name, topological_ordering, stream_ordering + FROM events + WHERE + room_id = ? + AND not outlier + AND stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering %s LIMIT ? 
+ """ % ( + order, + ) + txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ][:limit] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=False) if order.lower() == "desc": ret.reverse() @@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT m.event_id, instance_name, topological_ordering, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? + ORDER BY e.stream_ordering ASC + """ + txn.execute(sql, (user_id, min_from_id, max_to_id,)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ] return rows @@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int - ) -> Tuple[int, int, str]: + ) -> Optional[Tuple[int, int, str]]: """Gets details of the first event in a room at or before a stream ordering Args: @@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): ) return "t%d-%d" % (topo, token) - async def get_stream_id_for_event(self, event_id: str) -> int: - """The stream ID for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream ID. 
- """ - return await self.db_pool.runInteraction( - "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, - ) - def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none=False, ) -> int: @@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: order = "ASC" + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + if from_token.topological is not None: + from_bound = ( + from_token.as_historical_tuple() + ) # type: Tuple[Optional[int], int] + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound = None # type: Optional[Tuple[Optional[int], int]] + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds += " AND " + filter_clause args.extend(filter_args) - args.append(int(limit)) + # We fetch more events as we'll filter the result set + args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" @@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): select_keywords += "DISTINCT" sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering + %(select_keywords)s + event_id, instance_name, + topological_ordering, stream_ordering FROM events %(join_clause)s WHERE outlier = ? AND room_id = ? AND %(bounds)s @@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): txn.execute(sql, args) - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + # Filter the result set. + rows = [ + _EventDictReturn(event_id, topological_ordering, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + lower_token=to_token if direction == "b" else from_token, + upper_token=from_token if direction == "b" else to_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, + ) + ][:limit] if rows: topo = rows[-1].topological_ordering @@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): return (events, token) + @cached() + async def get_id_for_instance(self, instance_name: str) -> int: + """Get a unique, immutable ID that corresponds to the given Synapse worker instance. + """ + + def _get_id_for_instance_txn(txn): + instance_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + allow_none=True, + ) + if instance_id is not None: + return instance_id + + # If we don't have an entry upsert one. 
+ # + # We could do this before the first check, and rely on the cache for + # efficiency, but each UPSERT causes the next ID to increment which + # can quickly bloat the size of the generated IDs for new instances. + self.db_pool.simple_upsert_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + values={}, + ) + + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + ) + + return await self.db_pool.runInteraction( + "get_id_for_instance", _get_id_for_instance_txn + ) + + @cached() + async def get_name_from_instance_id(self, instance_id: int) -> str: + """Get the instance name from an ID previously returned by + `get_id_for_instance`. + """ + + return await self.db_pool.simple_select_one_onecol( + table="instance_map", + keyvalues={"instance_id": instance_id}, + retcol="instance_name", + desc="get_name_from_instance_id", + ) + class StreamStore(StreamWorkerStore): def get_room_max_stream_ordering(self) -> int: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 97aed1500e..59207cadd4 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple( SENTINEL = object() -class TransactionStore(SQLBaseStore): +class TransactionWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + + @wrap_as_background_process("cleanup_transactions") + async def _cleanup_transactions(self) -> None: + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + await self.db_pool.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) + + +class TransactionStore(TransactionWorkerStore): """A collection of queries for handling PDUs. 
""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) - self._destination_retry_cache = ExpiringCache( cache_name="get_destination_retry_timings", clock=self._clock, @@ -190,42 +208,56 @@ class TransactionStore(SQLBaseStore): """ self._destination_retry_cache.pop(destination, None) - return await self.db_pool.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings, - destination, - failure_ts, - retry_last_ts, - retry_interval, - ) + if self.database_engine.can_native_upsert: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_native, + destination, + failure_ts, + retry_last_ts, + retry_interval, + db_autocommit=True, # Safe as its a single upsert + ) + else: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_emulated, + destination, + failure_ts, + retry_last_ts, + retry_interval, + ) - def _set_destination_retry_timings( + def _set_destination_retry_timings_native( self, txn, destination, failure_ts, retry_last_ts, retry_interval ): + assert self.database_engine.can_native_upsert + + # Upsert retry time interval if retry_interval is zero (i.e. we're + # resetting it) or greater than the existing retry interval. + # + # WARNING: This is executed in autocommit, so we shouldn't add any more + # SQL calls in here (without being very careful). + sql = """ + INSERT INTO destinations ( + destination, failure_ts, retry_last_ts, retry_interval + ) + VALUES (?, ?, ?, ?) + ON CONFLICT (destination) DO UPDATE SET + failure_ts = EXCLUDED.failure_ts, + retry_last_ts = EXCLUDED.retry_last_ts, + retry_interval = EXCLUDED.retry_interval + WHERE + EXCLUDED.retry_interval = 0 + OR destinations.retry_interval IS NULL + OR destinations.retry_interval < EXCLUDED.retry_interval + """ - if self.database_engine.can_native_upsert: - # Upsert retry time interval if retry_interval is zero (i.e. we're - # resetting it) or greater than the existing retry interval. - - sql = """ - INSERT INTO destinations ( - destination, failure_ts, retry_last_ts, retry_interval - ) - VALUES (?, ?, ?, ?) 
- ON CONFLICT (destination) DO UPDATE SET - failure_ts = EXCLUDED.failure_ts, - retry_last_ts = EXCLUDED.retry_last_ts, - retry_interval = EXCLUDED.retry_interval - WHERE - EXCLUDED.retry_interval = 0 - OR destinations.retry_interval IS NULL - OR destinations.retry_interval < EXCLUDED.retry_interval - """ - - txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) - - return + txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + def _set_destination_retry_timings_emulated( + self, txn, destination, failure_ts, retry_last_ts, retry_interval + ): self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us @@ -266,22 +298,6 @@ class TransactionStore(SQLBaseStore): }, ) - def _start_cleanup_transactions(self): - return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions - ) - - async def _cleanup_transactions(self) -> None: - now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 - - def _cleanup_transactions_txn(txn): - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - - await self.db_pool.runInteraction( - "_cleanup_transactions", _cleanup_transactions_txn - ) - async def store_destination_rooms_entries( self, destinations: Iterable[str], room_id: str, stream_ordering: int, ) -> None: diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 3b9211a6d2..79b7ece330 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore): ) return [(row["user_agent"], row["ip"]) for row in rows] - -class UIAuthStore(UIAuthWorkerStore): async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore): iterable=session_ids, keyvalues={}, ) + + +class UIAuthStore(UIAuthWorkerStore): + pass diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 5a390ff2f6..d87ceec6da 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id_tuples: iterable of 2-tuple of user IDs. 
""" - def _add_users_who_share_room_txn(txn): - self.db_pool.simple_upsert_many_txn( - txn, - table="users_who_share_private_rooms", - key_names=["user_id", "other_user_id", "room_id"], - key_values=[ - (user_id, other_user_id, room_id) - for user_id, other_user_id in user_id_tuples - ], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_who_share_room", _add_users_who_share_room_txn + await self.db_pool.simple_upsert_many( + table="users_who_share_private_rooms", + key_names=["user_id", "other_user_id", "room_id"], + key_values=[ + (user_id, other_user_id, room_id) + for user_id, other_user_id in user_id_tuples + ], + value_names=(), + value_values=None, + desc="add_users_who_share_room", ) async def add_users_in_public_rooms( @@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_ids """ - def _add_users_in_public_rooms_txn(txn): - - self.db_pool.simple_upsert_many_txn( - txn, - table="users_in_public_rooms", - key_names=["user_id", "room_id"], - key_values=[(user_id, room_id) for user_id in user_ids], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_in_public_rooms", _add_users_in_public_rooms_txn + await self.db_pool.simple_upsert_many( + table="users_in_public_rooms", + key_names=["user_id", "room_id"], + key_values=[(user_id, room_id) for user_id in user_ids], + value_names=(), + value_values=None, + desc="add_users_in_public_rooms", ) async def delete_all_from_user_dir(self) -> None: |