diff options
Diffstat (limited to 'synapse/storage')
34 files changed, 883 insertions, 309 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index da3b99f93d..13de5f1f62 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta): members_changed (iterable[str]): The user_ids of members that have changed """ - for host in set(get_domain_from_id(u) for u in members_changed): + for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index bd547f35cf..eb1a7e5002 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -189,7 +189,7 @@ class BackgroundUpdater(object): keyvalues=None, retcols=("update_name", "depends_on"), ) - in_flight = set(update["update_name"] for update in updates) + in_flight = {update["update_name"] for update in updates} for update in updates: if update["depends_on"] not in in_flight: self._background_update_queue.append(update["update_name"]) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 2700cca822..acca079f23 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -20,6 +20,7 @@ import logging import time from synapse.api.constants import PresenceState +from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -117,16 +118,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - all_users_native = are_all_users_on_domain( - db_conn.cursor(), database.engine, hs.hostname - ) - if not all_users_native: - raise Exception( - "Found users in database not native to %s!\n" - "You cannot changed a synapse server_name after it's been configured" - % (hs.hostname,) - ) - self._stream_id_gen = StreamIdGenerator( db_conn, "events", @@ -567,13 +558,26 @@ class DataStore( ) -def are_all_users_on_domain(txn, database_engine, domain): +def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): + """Called before upgrading an existing database to check that it is broadly sane + compared with the configuration. + """ + domain = config.server_name + sql = database_engine.convert_param_style( "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" ) pat = "%:" + domain - txn.execute(sql, (pat,)) - num_not_matching = txn.fetchall()[0][0] + cur.execute(sql, (pat,)) + num_not_matching = cur.fetchall()[0][0] if num_not_matching == 0: - return True - return False + return + + raise Exception( + "Found users in database not native to %s!\n" + "You cannot changed a synapse server_name after it's been configured" + % (domain,) + ) + + +__all__ = ["DataStore", "check_database_before_upgrade"] diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index b2f39649fd..efbc06c796 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -135,7 +135,7 @@ class ApplicationServiceTransactionWorkerStore( may be empty. """ results = yield self.db.simple_select_list( - "application_services_state", dict(state=state), ["as_id"] + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -158,7 +158,7 @@ class ApplicationServiceTransactionWorkerStore( """ result = yield self.db.simple_select_one( "application_services_state", - dict(as_id=service.id), + {"as_id": service.id}, ["state"], allow_none=True, desc="get_appservice_state", @@ -177,7 +177,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves when the state was set successfully. """ return self.db.simple_upsert( - "application_services_state", dict(as_id=service.id), dict(state=state) + "application_services_state", {"as_id": service.id}, {"state": state} ) def create_appservice_txn(self, service, events): @@ -253,13 +253,15 @@ class ApplicationServiceTransactionWorkerStore( self.db.simple_upsert_txn( txn, "application_services_state", - dict(as_id=service.id), - dict(last_txn=txn_id), + {"as_id": service.id}, + {"last_txn": txn_id}, ) # Delete txn self.db.simple_delete_txn( - txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) + txn, + "application_services_txns", + {"txn_id": txn_id, "as_id": service.id}, ) return self.db.runInteraction( diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 13f4c9c72e..e1ccb27142 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -530,7 +530,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) for row in rows ) - return list( + return [ { "access_token": access_token, "ip": ip, @@ -538,7 +538,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "last_seen": last_seen, } for (access_token, ip), (user_agent, last_seen) in iteritems(results) - ) + ] @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self): diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index b7617efb80..d55733a4cd 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -137,7 +137,7 @@ class DeviceWorkerStore(SQLBaseStore): # get the cross-signing keys of the users in the list, so that we can # determine which of the device changes were cross-signing keys - users = set(r[0] for r in updates) + users = {r[0] for r in updates} master_key_by_user = {} self_signing_key_by_user = {} for user in users: @@ -446,7 +446,7 @@ class DeviceWorkerStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - user_ids = set(user_id for user_id, _ in query_list) + user_ids = {user_id for user_id, _ in query_list} user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists @@ -454,10 +454,9 @@ class DeviceWorkerStore(SQLBaseStore): users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( user_ids ) - user_ids_in_cache = ( - set(user_id for user_id, stream_id in user_map.items() if stream_id) - - users_needing_resync - ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -604,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore): rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return set(user for row in rows for user in json.loads(row[0])) + return {user for row in rows for user in json.loads(row[0])} else: return set() diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index e551606f9d..001a53f9b4 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -680,11 +680,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): 'user_signing' for a user-signing key key (dict): the key data """ - # the cross-signing keys need to occupy the same namespace as devices, - # since signatures are identified by device ID. So add an entry to the - # device table to make sure that we don't have a collision with device - # IDs - # the 'key' dict will look something like: # { # "user_id": "@alice:example.com", @@ -701,16 +696,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self.db.simple_insert_txn( - txn, - "devices", - values={ - "user_id": user_id, - "device_id": pubkey, - "display_name": key_type + " signing key", - "hidden": True, - }, - ) + + # The cross-signing keys need to occupy the same namespace as devices, + # since signatures are identified by device ID. So add an entry to the + # device table to make sure that we don't have a collision with device + # IDs. + # We only need to do this for local users, since remote servers should be + # responsible for checking this for their own users. + if self.hs.is_mine_id(user_id): + self.db.simple_insert_txn( + txn, + "devices", + values={ + "user_id": user_id, + "device_id": pubkey, + "display_name": key_type + " signing key", + "hidden": True, + }, + ) # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 60c67457b4..62d4e9f599 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,8 +14,8 @@ # limitations under the License. import itertools import logging +from typing import Dict, List, Optional, Set, Tuple -from six.moves import range from six.moves.queue import Empty, PriorityQueue from twisted.internet import defer @@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.database import Database from synapse.util.caches.descriptors import cached +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) @@ -46,21 +47,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_ids, include_given=include_given ).addCallback(self.get_events_as_list) - def get_auth_chain_ids(self, event_ids, include_given=False): + def get_auth_chain_ids( + self, + event_ids: List[str], + include_given: bool = False, + ignore_events: Optional[Set[str]] = None, + ): """Get auth events for given event_ids. The events *must* be state events. Args: - event_ids (list): state events - include_given (bool): include the given events in result + event_ids: state events + include_given: include the given events in result + ignore_events: Set of events to exclude from the returned auth + chain. This is useful if the caller will just discard the + given events anyway, and saves us from figuring out their auth + chains if not required. Returns: list of event_ids """ return self.db.runInteraction( - "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given + "get_auth_chain_ids", + self._get_auth_chain_ids_txn, + event_ids, + include_given, + ignore_events, ) - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): + if ignore_events is None: + ignore_events = set() + if include_given: results = set(event_ids) else: @@ -71,15 +88,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas front = set(event_ids) while front: new_front = set() - front_list = list(front) - chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] - for chunk in chunks: + for chunk in batch_iter(front, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "event_id", chunk ) - txn.execute(base_sql + clause, list(args)) - new_front.update([r[0] for r in txn]) + txn.execute(base_sql + clause, args) + new_front.update(r[0] for r in txn) + new_front -= ignore_events new_front -= results front = new_front @@ -87,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. + + Returns: + Deferred[Set[str]] + """ + + return self.db.runInteraction( + "get_auth_chain_difference", + self._get_auth_chain_difference_txn, + state_sets, + ) + + def _get_auth_chain_difference_txn( + self, txn, state_sets: List[Set[str]] + ) -> Set[str]: + + # Algorithm Description + # ~~~~~~~~~~~~~~~~~~~~~ + # + # The idea here is to basically walk the auth graph of each state set in + # tandem, keeping track of which auth events are reachable by each state + # set. If we reach an auth event we've already visited (via a different + # state set) then we mark that auth event and all ancestors as reachable + # by the state set. This requires that we keep track of the auth chains + # in memory. + # + # Doing it in a such a way means that we can stop early if all auth + # events we're currently walking are reachable by all state sets. + # + # *Note*: We can't stop walking an event's auth chain if it is reachable + # by all state sets. This is because other auth chains we're walking + # might be reachable only via the original auth chain. For example, + # given the following auth chain: + # + # A -> C -> D -> E + # / / + # B -´---------´ + # + # and state sets {A} and {B} then walking the auth chains of A and B + # would immediately show that C is reachable by both. However, if we + # stopped at C then we'd only reach E via the auth chain of B and so E + # would errornously get included in the returned difference. + # + # The other thing that we do is limit the number of auth chains we walk + # at once, due to practical limits (i.e. we can only query the database + # with a limited set of parameters). We pick the auth chains we walk + # each iteration based on their depth, in the hope that events with a + # lower depth are likely reachable by those with higher depths. + # + # We could use any ordering that we believe would give a rough + # topological ordering, e.g. origin server timestamp. If the ordering + # chosen is not topological then the algorithm still produces the right + # result, but perhaps a bit more inefficiently. This is why it is safe + # to use "depth" here. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Dict from events in auth chains to which sets *cannot* reach them. + # I.e. if the set is empty then all sets can reach the event. + event_to_missing_sets = { + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} + for event_id in initial_events + } + + # We need to get the depth of the initial events for sorting purposes. + sql = """ + SELECT depth, event_id FROM events + WHERE %s + ORDER BY depth ASC + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", initial_events + ) + txn.execute(sql % (clause,), args) + + # The sorted list of events whose auth chains we should walk. + search = txn.fetchall() # type: List[Tuple[int, str]] + + # Map from event to its auth events + event_to_auth_events = {} # type: Dict[str, Set[str]] + + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ + + while search: + # Check whether all our current walks are reachable by all state + # sets. If so we can bail. + if all(not event_to_missing_sets[eid] for _, eid in search): + break + + # Fetch the auth events and their depths of the N last events we're + # currently walking + search, chunk = search[:-100], search[-100:] + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] + ) + txn.execute(base_sql + clause, args) + + for event_id, auth_event_id, auth_event_depth in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) + + sets = event_to_missing_sets.get(auth_event_id) + if sets is None: + # First time we're seeing this event, so we add it to the + # queue of things to fetch. + search.append((auth_event_depth, auth_event_id)) + + # Assume that this event is unreachable from any of the + # state sets until proven otherwise + sets = event_to_missing_sets[auth_event_id] = set( + range(len(state_sets)) + ) + else: + # We've previously seen this event, so look up its auth + # events and recursively mark all ancestors as reachable + # by the current event's state set. + a_ids = event_to_auth_events.get(auth_event_id) + while a_ids: + new_aids = set() + for a_id in a_ids: + event_to_missing_sets[a_id].intersection_update( + event_to_missing_sets[event_id] + ) + + b = event_to_auth_events.get(a_id) + if b: + new_aids.update(b) + + a_ids = new_aids + + # Mark that the auth event is reachable by the approriate sets. + sets.intersection_update(event_to_missing_sets[event_id]) + + search.sort() + + # Return all events where not all sets can reach them. + return {eid for eid, n in event_to_missing_sets.items() if n} + def get_oldest_events_in_room(self, room_id): return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id @@ -410,7 +574,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas query, (room_id, event_id, False, limit - len(event_results)) ) - new_results = set(t[0] for t in txn) - seen_events + new_results = {t[0] for t in txn} - seen_events new_front |= new_results seen_events |= new_results diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 9988a6d3fc..8eed590929 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -608,6 +608,23 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end + @defer.inlineCallbacks + def get_time_of_last_push_action_before(self, stream_ordering): + def f(txn): + sql = ( + "SELECT e.received_ts" + " FROM event_push_actions AS ep" + " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" + " WHERE ep.stream_ordering > ?" + " ORDER BY ep.stream_ordering ASC" + " LIMIT 1" + ) + txn.execute(sql, (stream_ordering,)) + return txn.fetchone() + + result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + return result[0] if result else None + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" @@ -736,23 +753,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return push_actions @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): - sql = ( - "SELECT e.received_ts" - " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ?" - " ORDER BY ep.stream_ordering ASC" - " LIMIT 1" - ) - txn.execute(sql, (stream_ordering,)) - return txn.fetchone() - - result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) - return result[0] if result else None - - @defer.inlineCallbacks def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index c9d0d68c3a..d593ef47b8 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -145,7 +145,7 @@ class EventsStore( return txn.fetchall() res = yield self.db.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) + self._current_forward_extremities_amount = c_counter([x[0] for x in res]) @_retry_on_integrity_error @defer.inlineCallbacks @@ -598,11 +598,11 @@ class EventsStore( # We find out which membership events we may have deleted # and which we have added, then we invlidate the caches for all # those users. - members_changed = set( + members_changed = { state_key for ev_type, state_key in itertools.chain(to_delete, to_insert) if ev_type == EventTypes.Member - ) + } for member in members_changed: txn.call_after( @@ -1168,7 +1168,11 @@ class EventsStore( and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json(prune_event_dict(original_event.get_dict())) + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) else: # Redaction wasn't allowed pruned_json = None @@ -1615,7 +1619,7 @@ class EventsStore( """ ) - referenced_state_groups = set(sg for sg, in txn) + referenced_state_groups = {sg for sg, in txn} logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups) ) @@ -1929,7 +1933,9 @@ class EventsStore( return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json(prune_event_dict(event.get_dict())) + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) # Update the event_json table to replace the event's JSON with the pruned # JSON. diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 5177b71016..f54c8b1ee0 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -402,7 +402,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): keyvalues={}, retcols=("room_id",), ) - room_ids = set(row["room_id"] for row in rows) + room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, (room_id,) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 7251e819f5..ca237c6f12 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -28,9 +28,12 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError -from synapse.api.room_versions import EventFormatVersions -from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) +from synapse.events import make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process @@ -494,9 +497,9 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - events_to_fetch = set( + events_to_fetch = { event_id for events, _ in event_list for event_id in events - ) + } row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch @@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore): # of a event format version, so it must be a V1 event. format_version = EventFormatVersions.V1 - original_ev = event_type_from_format_version(format_version)( + room_version_id = row["room_version_id"] + + if not room_version_id: + # this should only happen for out-of-band membership events + if not internal_metadata.get("out_of_band_membership"): + logger.warning( + "Room %s for event %s is unknown", d["room_id"], event_id + ) + continue + + # take a wild stab at the room version based on the event format + if format_version == EventFormatVersions.V1: + room_version = RoomVersions.V1 + elif format_version == EventFormatVersions.V2: + room_version = RoomVersions.V3 + else: + room_version = RoomVersions.V5 + else: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + logger.error( + "Event %s in room %s has unknown room version %s", + event_id, + d["room_id"], + room_version_id, + ) + continue + + if room_version.event_format != format_version: + logger.error( + "Event %s in room %s with version %s has wrong format: " + "expected %s, was %s", + event_id, + d["room_id"], + room_version_id, + room_version.event_format, + format_version, + ) + continue + + original_ev = make_event_from_dict( event_dict=d, + room_version=room_version, internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) @@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore): of EventFormatVersions. 'None' means the event predates EventFormatVersions (so the event is format V1). + * room_version_id (str|None): The version of the room which contains the event. + Hopefully one of RoomVersions. + + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + * rejected_reason (str|None): if the event was rejected, the reason why. @@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore): """ event_dict = {} for evs in batch_iter(event_ids, 200): - sql = ( - "SELECT " - " e.event_id, " - " e.internal_metadata," - " e.json," - " e.format_version, " - " rej.reason " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " WHERE " - ) + sql = """\ + SELECT + e.event_id, + e.internal_metadata, + e.json, + e.format_version, + r.room_version, + rej.reason + FROM event_json as e + LEFT JOIN rooms r USING (room_id) + LEFT JOIN rejections as rej USING (event_id) + WHERE """ clause, args = make_in_list_sql_clause( txn.database_engine, "e.event_id", evs @@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore): "internal_metadata": row[1], "json": row[2], "format_version": row[3], - "rejected_reason": row[4], + "room_version_id": row[4], + "rejected_reason": row[5], "redactions": [], } @@ -804,7 +856,7 @@ class EventsWorkerStore(SQLBaseStore): desc="have_events_in_timeline", ) - return set(r["event_id"] for r in rows) + return {r["event_id"] for r in rows} @defer.inlineCallbacks def have_seen_events(self, event_ids): diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 1507a14e09..925bc5691b 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users(txn): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - txn.execute(sql) (count,) = txn.fetchone() return count return self.db.runInteraction("count_users", _count_users) + @cached(num_args=0) + def get_monthly_active_count_by_service(self): + """Generates current count of monthly active users broken down by service. + A service is typically an appservice but also includes native matrix users. + Since the `monthly_active_users` table is populated from the `user_ips` table + `config.track_appservice_user_ips` must be set to `true` for this + method to return anything other than native matrix users. + + Returns: + Deferred[dict]: dict that includes a mapping between app_service_id + and the number of occurrences. + + """ + + def _count_users_by_service(txn): + sql = """ + SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + GROUP BY appservice_id; + """ + + txn.execute(sql) + result = txn.fetchall() + return dict(result) + + return self.db.runInteraction("count_users_by_service", _count_users_by_service) + @defer.inlineCallbacks def get_registered_reserved_users(self): """Of the reserved threepids defined in config, which are associated @@ -292,6 +319,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) self._invalidate_cache_and_stream( + txn, self.get_monthly_active_count_by_service, () + ) + self._invalidate_cache_and_stream( txn, self.user_last_seen_monthly_active, (user_id,) ) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index e2673ae073..62ac88d9f2 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -276,21 +276,21 @@ class PushRulesWorkerStore( # We ignore app service users for now. This is so that we don't fill # up the `get_if_users_have_pushers` cache with AS entries that we # know don't have pushers, nor even read receipts. - local_users_in_room = set( + local_users_in_room = { u for u in users_in_room if self.hs.is_mine_id(u) and not self.get_if_app_services_interested_in_user(u) - ) + } # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( local_users_in_room, on_invalidate=cache_context.invalidate ) - user_ids = set( + user_ids = { uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) + } users_with_receipts = yield self.get_users_with_read_receipts_in_room( room_id, on_invalidate=cache_context.invalidate diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index 6b03233262..547b9d69cb 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -197,6 +197,84 @@ class PusherWorkerStore(SQLBaseStore): return result + @defer.inlineCallbacks + def update_pusher_last_stream_ordering( + self, app_id, pushkey, user_id, last_stream_ordering + ): + yield self.db.simple_update_one( + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, + desc="update_pusher_last_stream_ordering", + ) + + @defer.inlineCallbacks + def update_pusher_last_stream_ordering_and_success( + self, app_id, pushkey, user_id, last_stream_ordering, last_success + ): + """Update the last stream ordering position we've processed up to for + the given pusher. + + Args: + app_id (str) + pushkey (str) + last_stream_ordering (int) + last_success (int) + + Returns: + Deferred[bool]: True if the pusher still exists; False if it has been deleted. + """ + updated = yield self.db.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={ + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, + }, + desc="update_pusher_last_stream_ordering_and_success", + ) + + return bool(updated) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): + yield self.db.simple_update( + table="pushers", + keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + updatevalues={"failing_since": failing_since}, + desc="update_pusher_failing_since", + ) + + @defer.inlineCallbacks + def get_throttle_params_by_room(self, pusher_id): + res = yield self.db.simple_select_list( + "pusher_throttle", + {"pusher": pusher_id}, + ["room_id", "last_sent_ts", "throttle_ms"], + desc="get_throttle_params_by_room", + ) + + params_by_room = {} + for row in res: + params_by_room[row["room_id"]] = { + "last_sent_ts": row["last_sent_ts"], + "throttle_ms": row["throttle_ms"], + } + + return params_by_room + + @defer.inlineCallbacks + def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so simple_upsert will retry + yield self.db.simple_upsert( + "pusher_throttle", + {"pusher": pusher_id, "room_id": room_id}, + params, + desc="set_throttle_params", + lock=False, + ) + class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self): @@ -282,81 +360,3 @@ class PusherStore(PusherWorkerStore): with self._pushers_id_gen.get_next() as stream_id: yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( - self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self.db.simple_update_one( - "pushers", - {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - {"last_stream_ordering": last_stream_ordering}, - desc="update_pusher_last_stream_ordering", - ) - - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): - """Update the last stream ordering position we've processed up to for - the given pusher. - - Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) - - Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. - """ - updated = yield self.db.simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={ - "last_stream_ordering": last_stream_ordering, - "last_success": last_success, - }, - desc="update_pusher_last_stream_ordering_and_success", - ) - - return bool(updated) - - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.db.simple_update( - table="pushers", - keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, - updatevalues={"failing_since": failing_since}, - desc="update_pusher_failing_since", - ) - - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self.db.simple_select_list( - "pusher_throttle", - {"pusher": pusher_id}, - ["room_id", "last_sent_ts", "throttle_ms"], - desc="get_throttle_params_by_room", - ) - - params_by_room = {} - for row in res: - params_by_room[row["room_id"]] = { - "last_sent_ts": row["last_sent_ts"], - "throttle_ms": row["throttle_ms"], - } - - return params_by_room - - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): - # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so simple_upsert will retry - yield self.db.simple_upsert( - "pusher_throttle", - {"pusher": pusher_id, "room_id": room_id}, - params, - desc="set_throttle_params", - lock=False, - ) diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 96e54d145e..0d932a0672 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - return set(r["user_id"] for r in receipts) + return {r["user_id"] for r in receipts} @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): @@ -283,7 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return list(r[0:5] + (json.loads(r[5]),) for r in txn) + return [r[0:5] + (json.loads(r[5]),) for r in txn] return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 9a17e336ba..e6c10c6316 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -954,6 +954,23 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.config = hs.config + async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): + """Ensure that the room is stored in the table + + Called when we join a room over federation, and overwrites any room version + currently in the table. + """ + await self.db.simple_upsert( + desc="upsert_room_on_join", + table="rooms", + keyvalues={"room_id": room_id}, + values={"room_version": room_version.identifier}, + insertion_values={"is_public": False, "creator": ""}, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. + lock=False, + ) + @defer.inlineCallbacks def store_room( self, @@ -1003,6 +1020,26 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") + async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + """ + When we receive an invite over federation, store the version of the room if we + don't already know the room version. + """ + await self.db.simple_upsert( + desc="maybe_store_room_on_invite", + table="rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={ + "room_version": room_version.identifier, + "is_public": False, + "creator": "", + }, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. + lock=False, + ) + @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index d5ced05701..d5bd0cb5cf 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -465,7 +465,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql % (clause,), args) - return set(row[0] for row in txn) + return {row[0] for row in txn} return await self.db.runInteraction( "get_users_server_still_shares_room_with", @@ -826,7 +826,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): GROUP BY room_id, user_id; """ txn.execute(sql, (user_id,)) - return set(row[0] for row in txn if row[1] == 0) + return {row[0] for row in txn if row[1] == 0} return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres new file mode 100644 index 0000000000..92aaadde0d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres @@ -0,0 +1,39 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- When we first added the room_version column to the rooms table, it was populated from +-- the current_state_events table. However, there was an issue causing a background +-- update to clean up the current_state_events table for rooms where the server is no +-- longer participating, before that column could be populated. Therefore, some rooms had +-- a NULL room_version. + +-- The rooms_version_column_2.sql.* delta files were introduced to make the populating +-- synchronous instead of running it in a background update, which fixed this issue. +-- However, all of the instances of Synapse installed or updated in the meantime got +-- their rooms table corrupted with NULL room_versions. + +-- This query fishes out the room versions from the create event using the state_events +-- table instead of the current_state_events one, as the former still have all of the +-- create events. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json::json->'content'->>'room_version','1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; + +-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using +-- sqlite syntax for the json extraction. diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite new file mode 100644 index 0000000000..e19dab97cb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite @@ -0,0 +1,23 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- see rooms_version_column_3.sql.postgres for details of what's going on here. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md index bbd3f18604..c00f287190 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.md +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md @@ -1,13 +1,21 @@ -# Building full schema dumps +# Synapse Database Schemas -These schemas need to be made from a database that has had all background updates run. +These schemas are used as a basis to create brand new Synapse databases, on both +SQLite3 and Postgres. -To do so, use `scripts-dev/make_full_schema.sh`. This will produce -`full.sql.postgres ` and `full.sql.sqlite` files. +## Building full schema dumps + +If you want to recreate these schemas, they need to be made from a database that +has had all background updates run. + +To do so, use `scripts-dev/make_full_schema.sh`. This will produce new +`full.sql.postgres ` and `full.sql.sqlite` files. Ensure postgres is installed and your user has the ability to run bash commands -such as `createdb`. +such as `createdb`, then call + + ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ -``` -./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ -``` +There are currently two folders with full-schema snapshots. `16` is a snapshot +from 2015, for historical reference. The other contains the most recent full +schema snapshot. diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 3d34103e67..3a3b9a8e72 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -321,7 +321,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_referenced_state_groups", ) - return set(row["state_group"] for row in rows) + return {row["state_group"] for row in rows} class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): @@ -367,7 +367,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): """ txn.execute(sql, (last_room_id, batch_size)) - room_ids = list(row[0] for row in txn) + room_ids = [row[0] for row in txn] if not room_ids: return True, set() @@ -384,7 +384,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - joined_room_ids = set(row[0] for row in txn) + joined_room_ids = {row[0] for row in txn} left_rooms = set(room_ids) - joined_room_ids @@ -404,7 +404,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): retcols=("state_key",), ) - potentially_left_users = set(row["state_key"] for row in rows) + potentially_left_users = {row["state_key"] for row in rows} # Now lets actually delete the rooms from the DB. self.db.simple_delete_many_txn( diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 12c982cb26..725e12507f 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -15,6 +15,8 @@ import logging +from twisted.internet import defer + from synapse.storage._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -56,7 +58,7 @@ class StateDeltasStore(SQLBaseStore): # if the CSDs haven't changed between prev_stream_id and now, we # know for certain that they haven't changed between prev_stream_id and # max_stream_id. - return max_stream_id, [] + return defer.succeed((max_stream_id, [])) def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 056b25b13a..ada5cce6c2 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -346,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream - return set( + return { room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) - ) + } @defer.inlineCallbacks def get_room_events_stream_for_room( @@ -679,11 +679,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) events_before = yield self.get_events_as_list( - [e for e in results["before"]["event_ids"]], get_prev_content=True + list(results["before"]["event_ids"]), get_prev_content=True ) events_after = yield self.get_events_as_list( - [e for e in results["after"]["event_ids"]], get_prev_content=True + list(results["after"]["event_ids"]), get_prev_content=True ) return { diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index af8025bc17..ec6b8a4ffd 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -63,9 +63,9 @@ class UserErasureWorkerStore(SQLBaseStore): retcols=("user_id",), desc="are_users_erased", ) - erased_users = set(row["user_id"] for row in rows) + erased_users = {row["user_id"] for row in rows} - res = dict((u, u in erased_users) for u in user_ids) + res = {u: u in erased_users for u in user_ids} return res diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index c4ee9b7ccb..57a5267663 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -520,11 +520,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): retcols=("state_group",), ) - remaining_state_groups = set( + remaining_state_groups = { row["state_group"] for row in rows if row["state_group"] not in state_groups_to_delete - ) + } logger.info( "[purge] de-delta-ing %i remaining state groups", diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 3eeb2f7c04..e61595336c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -15,9 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import sys import time -from typing import Iterable, Tuple +from time import monotonic as monotonic_time +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from six import iteritems, iterkeys, itervalues from six.moves import intern, range @@ -29,27 +29,21 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig -from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.logging.context import ( + LoggingContext, + LoggingContextOrSentinel, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Connection, Cursor from synapse.util.stringutils import exception_to_unicode -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time - logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 +# python 3 does not have a maximum int value +MAX_TXN_ID = 2 ** 63 - 1 sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") @@ -77,7 +71,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { def make_pool( - reactor, db_config: DatabaseConnectionConfig, engine + reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine ) -> adbapi.ConnectionPool: """Get the connection pool for the database. """ @@ -90,7 +84,9 @@ def make_pool( ) -def make_conn(db_config: DatabaseConnectionConfig, engine): +def make_conn( + db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine +) -> Connection: """Make a new connection to the database and return it. Returns: @@ -107,20 +103,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine): return db_conn -class LoggingTransaction(object): +# The type of entry which goes on our after_callbacks and exception_callbacks lists. +# +# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so +# that mypy sees the type but the runtime python doesn't. +_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] + + +class LoggingTransaction: """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method. Args: txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to + name: The name of this transactions for logging. + database_engine + after_callbacks: A list that callbacks will be appended to that have been added by `call_after` which should be run on successful completion of the transaction. None indicates that no callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended + exception_callbacks: A list that callbacks will be appended to that have been added by `call_on_exception` which should be run if transaction ends with an error. None indicates that no callbacks should be allowed to be scheduled to run. @@ -135,46 +138,67 @@ class LoggingTransaction(object): ] def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + self, + txn: Cursor, + name: str, + database_engine: BaseDatabaseEngine, + after_callbacks: Optional[List[_CallbackListEntry]] = None, + exception_callbacks: Optional[List[_CallbackListEntry]] = None, ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) + self.txn = txn + self.name = name + self.database_engine = database_engine + self.after_callbacks = after_callbacks + self.exception_callbacks = exception_callbacks - def call_after(self, callback, *args, **kwargs): + def call_after(self, callback: "Callable[..., None]", *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ + # if self.after_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) - def call_on_exception(self, callback, *args, **kwargs): + def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): + # if self.exception_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) - def __getattr__(self, name): - return getattr(self.txn, name) + def fetchall(self) -> List[Tuple]: + return self.txn.fetchall() - def __setattr__(self, name, value): - setattr(self.txn, name, value) + def fetchone(self) -> Tuple: + return self.txn.fetchone() - def __iter__(self): + def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() + @property + def rowcount(self) -> int: + return self.txn.rowcount + + @property + def description(self) -> Any: + return self.txn.description + def execute_batch(self, sql, args): if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch + from psycopg2.extras import execute_batch # type: ignore self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: for val in args: self.execute(sql, val) - def execute(self, sql, *args): + def execute(self, sql: str, *args: Any): self._do_execute(self.txn.execute, sql, *args) - def executemany(self, sql, *args): + def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) def _make_sql_one_line(self, sql): @@ -207,6 +231,9 @@ class LoggingTransaction(object): sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_query_timer.labels(sql.split()[0]).observe(secs) + def close(self): + self.txn.close() + class PerformanceCounters(object): def __init__(self): @@ -251,7 +278,9 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): + def __init__( + self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + ): self.hs = hs self._clock = hs.get_clock() self._database_config = database_config @@ -259,9 +288,9 @@ class Database(object): self.updates = BackgroundUpdater(hs, self) - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 + self._previous_txn_total_time = 0.0 + self._current_txn_total_time = 0.0 + self._previous_loop_ts = 0.0 # TODO(paul): These can eventually be removed once the metrics code # is running in mainline, and we have some nice monitoring frontends @@ -463,23 +492,23 @@ class Database(object): sql_txn_timer.labels(desc).observe(duration) @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): + def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): """Starts a transaction on the database and runs a given function Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a + desc: description of the transaction, for logging and metrics + func: callback function, which will be called with a database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func """ - after_callbacks = [] - exception_callbacks = [] + after_callbacks = [] # type: List[_CallbackListEntry] + exception_callbacks = [] # type: List[_CallbackListEntry] if LoggingContext.current_context() == LoggingContext.sentinel: logger.warning("Starting db txn '%s' from sentinel context", desc) @@ -505,20 +534,22 @@ class Database(object): return result @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): + def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any): """Wraps the .runWithConnection() method on the underlying db_pool. Arguments: - func (func): callback function, which will be called with a + func: callback function, which will be called with a database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func """ - parent_context = LoggingContext.current_context() + parent_context = ( + LoggingContext.current_context() + ) # type: Optional[LoggingContextOrSentinel] if parent_context == LoggingContext.sentinel: logger.warning( "Starting db connection from sentinel context: metrics will be lost" @@ -554,8 +585,8 @@ class Database(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) + col_headers = [intern(str(column[0])) for column in cursor.description] + results = [dict(zip(col_headers, row)) for row in cursor] return results def execute(self, desc, decoder, query, *args): @@ -800,7 +831,7 @@ class Database(object): return False # We didn't find any existing rows, so insert a new one - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(values) allvalues.update(insertion_values) @@ -829,7 +860,7 @@ class Database(object): Returns: None """ - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(insertion_values) @@ -916,7 +947,7 @@ class Database(object): Returns: None """ - allnames = [] + allnames = [] # type: List[str] allnames.extend(key_names) allnames.extend(value_names) @@ -1100,7 +1131,7 @@ class Database(object): keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ - results = [] + results = [] # type: List[Dict[str, Any]] if not iterable: return results @@ -1439,7 +1470,7 @@ class Database(object): raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") where_clause = "WHERE " if filters or keyvalues else "" - arg_list = [] + arg_list = [] # type: List[Any] if filters: where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) arg_list += list(filters.values()) @@ -1504,7 +1535,7 @@ class Database(object): def make_in_list_sql_clause( database_engine, column: str, iterable: Iterable -) -> Tuple[str, Iterable]: +) -> Tuple[str, list]: """Returns an SQL clause that checks the given column is in the iterable. On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 9d2d519922..035f9ea6e9 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -12,29 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import importlib import platform -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} - -def create_engine(database_config): +def create_engine(database_config) -> BaseDatabaseEngine: name = database_config["name"] - engine_class = SUPPORTED_MODULE.get(name, None) - if engine_class: + if name == "sqlite3": + import sqlite3 + + return Sqlite3Engine(sqlite3, database_config) + + if name == "psycopg2": # pypy requires psycopg2cffi rather than psycopg2 - if name == "psycopg2" and platform.python_implementation() == "PyPy": - name = "psycopg2cffi" - module = importlib.import_module(name) - return engine_class(module, database_config) + if platform.python_implementation() == "PyPy": + import psycopg2cffi as psycopg2 # type: ignore + else: + import psycopg2 # type: ignore + + return PostgresEngine(psycopg2, database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) -__all__ = ["create_engine", "IncorrectDatabaseSetup"] +__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"] diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ec5a4d198b..ab0bbe4bd3 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -12,7 +12,94 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc +from typing import Generic, TypeVar + +from synapse.storage.types import Connection class IncorrectDatabaseSetup(RuntimeError): pass + + +ConnectionType = TypeVar("ConnectionType", bound=Connection) + + +class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): + def __init__(self, module, database_config: dict): + self.module = module + + @property + @abc.abstractmethod + def single_threaded(self) -> bool: + ... + + @property + @abc.abstractmethod + def can_native_upsert(self) -> bool: + """ + Do we support native UPSERTs? + """ + ... + + @property + @abc.abstractmethod + def supports_tuple_comparison(self) -> bool: + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + ... + + @property + @abc.abstractmethod + def supports_using_any_list(self) -> bool: + """ + Do we support using `a = ANY(?)` and passing a list + """ + ... + + @abc.abstractmethod + def check_database( + self, db_conn: ConnectionType, allow_outdated_version: bool = False + ) -> None: + ... + + @abc.abstractmethod + def check_new_database(self, txn) -> None: + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + ... + + @abc.abstractmethod + def convert_param_style(self, sql: str) -> str: + ... + + @abc.abstractmethod + def on_new_connection(self, db_conn: ConnectionType) -> None: + ... + + @abc.abstractmethod + def is_deadlock(self, error: Exception) -> bool: + ... + + @abc.abstractmethod + def is_connection_closed(self, conn: ConnectionType) -> bool: + ... + + @abc.abstractmethod + def lock_table(self, txn, table: str) -> None: + ... + + @abc.abstractmethod + def get_next_state_group_id(self, txn) -> int: + """Returns an int that can be used as a new state_group ID + """ + ... + + @property + @abc.abstractmethod + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0' + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a077345960..6c7d08a6f2 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -15,16 +15,14 @@ import logging -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup logger = logging.getLogger(__name__) -class PostgresEngine(object): - single_threaded = False - +class PostgresEngine(BaseDatabaseEngine): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) self.module.extensions.register_type(self.module.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do @@ -36,6 +34,10 @@ class PostgresEngine(object): self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet + @property + def single_threaded(self) -> bool: + return False + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and @@ -53,7 +55,7 @@ class PostgresEngine(object): if rows and rows[0][0] != "UTF8": raise IncorrectDatabaseSetup( "Database has incorrect encoding: '%s' instead of 'UTF8'\n" - "See docs/postgres.rst for more information." % (rows[0][0],) + "See docs/postgres.md for more information." % (rows[0][0],) ) txn.execute( @@ -62,12 +64,16 @@ class PostgresEngine(object): collation, ctype = txn.fetchone() if collation != "C": logger.warning( - "Database has incorrect collation of %r. Should be 'C'", collation + "Database has incorrect collation of %r. Should be 'C'\n" + "See docs/postgres.md for more information.", + collation, ) if ctype != "C": logger.warning( - "Database has incorrect ctype of %r. Should be 'C'", ctype + "Database has incorrect ctype of %r. Should be 'C'\n" + "See docs/postgres.md for more information.", + ctype, ) def check_new_database(self, txn): diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 641e490697..2bfeefd54e 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import sqlite3 import struct import threading +from synapse.storage.engines import BaseDatabaseEngine -class Sqlite3Engine(object): - single_threaded = True +class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in (None, ":memory:",) @@ -32,6 +32,10 @@ class Sqlite3Engine(object): self._current_state_group_id_lock = threading.Lock() @property + def single_threaded(self) -> bool: + return True + + @property def can_native_upsert(self): """ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some @@ -68,7 +72,6 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - # We need to import here to avoid an import loop. from synapse.storage.prepare_database import prepare_database diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index b950550f23..0f9ac1cf09 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -602,14 +602,14 @@ class EventsPersistenceStorage(object): event_id_to_state_group.update(event_to_groups) # State groups of old_latest_event_ids - old_state_groups = set( + old_state_groups = { event_id_to_state_group[evid] for evid in old_latest_event_ids - ) + } # State groups of new_latest_event_ids - new_state_groups = set( + new_state_groups = { event_id_to_state_group[evid] for evid in new_latest_event_ids - ) + } # If they old and new groups are the same then we don't need to do # anything. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index c285ef52a0..6cb7d4b922 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -278,13 +278,17 @@ def _upgrade_existing_database( the current_version wasn't generated by applying those delta files. database_engine (DatabaseEngine) config (synapse.config.homeserver.HomeServerConfig|None): - application config, or None if we are connecting to an existing - database which we expect to be configured already + None if we are initialising a blank database, otherwise the application + config data_stores (list[str]): The names of the data stores to instantiate on the given database. is_empty (bool): Is this a blank database? I.e. do we need to run the upgrade portions of the delta scripts. """ + if is_empty: + assert not applied_delta_files + else: + assert config if current_version > SCHEMA_VERSION: raise ValueError( @@ -292,6 +296,13 @@ def _upgrade_existing_database( + "new for the server to understand" ) + # some of the deltas assume that config.server_name is set correctly, so now + # is a good time to run the sanity check. + if not is_empty and "main" in data_stores: + from synapse.storage.data_stores.main import check_database_before_upgrade + + check_database_before_upgrade(cur, database_engine, config) + start_ver = current_version if not upgraded: start_ver += 1 @@ -345,9 +356,9 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) - duplicates = set( + duplicates = { file_name for file_name, count in file_name_counter.items() if count > 1 - ) + } if duplicates: # We don't support using the same file name in the same delta version. raise PrepareDatabaseException( @@ -454,7 +465,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ), (modname,), ) - applied_deltas = set(d for d, in cur) + applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/storage/types.py b/synapse/storage/types.py new file mode 100644 index 0000000000..daff81c5ee --- /dev/null +++ b/synapse/storage/types.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterable, Iterator, List, Tuple + +from typing_extensions import Protocol + + +""" +Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 +""" + + +class Cursor(Protocol): + def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + ... + + def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + ... + + def fetchall(self) -> List[Tuple]: + ... + + def fetchone(self) -> Tuple: + ... + + @property + def description(self) -> Any: + return None + + @property + def rowcount(self) -> int: + return 0 + + def __iter__(self) -> Iterator[Tuple]: + ... + + def close(self) -> None: + ... + + +class Connection(Protocol): + def cursor(self) -> Cursor: + ... + + def close(self) -> None: + ... + + def commit(self) -> None: + ... + + def rollback(self, *args, **kwargs) -> None: + ... |