summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/devices.py24
-rw-r--r--synapse/storage/data_stores/main/event_federation.py150
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py34
-rw-r--r--synapse/storage/data_stores/main/events.py10
-rw-r--r--synapse/storage/data_stores/main/events_worker.py84
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py32
-rw-r--r--synapse/storage/data_stores/main/registration.py16
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres39
-rw-r--r--synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite23
-rw-r--r--synapse/storage/data_stores/main/state_deltas.py4
-rw-r--r--synapse/storage/database.py153
-rw-r--r--synapse/storage/engines/__init__.py28
-rw-r--r--synapse/storage/engines/_base.py87
-rw-r--r--synapse/storage/engines/postgres.py12
-rw-r--r--synapse/storage/engines/sqlite.py16
-rw-r--r--synapse/storage/types.py65
16 files changed, 639 insertions, 138 deletions
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py

index d55733a4cd..8af5f7de54 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py
@@ -317,14 +317,16 @@ class DeviceWorkerStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) else: result["deleted"] = True @@ -525,14 +527,16 @@ class DeviceWorkerStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) results.append(result) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 49a7b8b433..62d4e9f599 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,7 +14,7 @@ # limitations under the License. import itertools import logging -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from six.moves.queue import Empty, PriorityQueue @@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. + + Returns: + Deferred[Set[str]] + """ + + return self.db.runInteraction( + "get_auth_chain_difference", + self._get_auth_chain_difference_txn, + state_sets, + ) + + def _get_auth_chain_difference_txn( + self, txn, state_sets: List[Set[str]] + ) -> Set[str]: + + # Algorithm Description + # ~~~~~~~~~~~~~~~~~~~~~ + # + # The idea here is to basically walk the auth graph of each state set in + # tandem, keeping track of which auth events are reachable by each state + # set. If we reach an auth event we've already visited (via a different + # state set) then we mark that auth event and all ancestors as reachable + # by the state set. This requires that we keep track of the auth chains + # in memory. + # + # Doing it in a such a way means that we can stop early if all auth + # events we're currently walking are reachable by all state sets. + # + # *Note*: We can't stop walking an event's auth chain if it is reachable + # by all state sets. This is because other auth chains we're walking + # might be reachable only via the original auth chain. For example, + # given the following auth chain: + # + # A -> C -> D -> E + # / / + # B -´---------´ + # + # and state sets {A} and {B} then walking the auth chains of A and B + # would immediately show that C is reachable by both. However, if we + # stopped at C then we'd only reach E via the auth chain of B and so E + # would errornously get included in the returned difference. + # + # The other thing that we do is limit the number of auth chains we walk + # at once, due to practical limits (i.e. we can only query the database + # with a limited set of parameters). We pick the auth chains we walk + # each iteration based on their depth, in the hope that events with a + # lower depth are likely reachable by those with higher depths. + # + # We could use any ordering that we believe would give a rough + # topological ordering, e.g. origin server timestamp. If the ordering + # chosen is not topological then the algorithm still produces the right + # result, but perhaps a bit more inefficiently. This is why it is safe + # to use "depth" here. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Dict from events in auth chains to which sets *cannot* reach them. + # I.e. if the set is empty then all sets can reach the event. + event_to_missing_sets = { + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} + for event_id in initial_events + } + + # We need to get the depth of the initial events for sorting purposes. + sql = """ + SELECT depth, event_id FROM events + WHERE %s + ORDER BY depth ASC + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", initial_events + ) + txn.execute(sql % (clause,), args) + + # The sorted list of events whose auth chains we should walk. + search = txn.fetchall() # type: List[Tuple[int, str]] + + # Map from event to its auth events + event_to_auth_events = {} # type: Dict[str, Set[str]] + + base_sql = """ + SELECT a.event_id, auth_id, depth + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ + + while search: + # Check whether all our current walks are reachable by all state + # sets. If so we can bail. + if all(not event_to_missing_sets[eid] for _, eid in search): + break + + # Fetch the auth events and their depths of the N last events we're + # currently walking + search, chunk = search[:-100], search[-100:] + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] + ) + txn.execute(base_sql + clause, args) + + for event_id, auth_event_id, auth_event_depth in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) + + sets = event_to_missing_sets.get(auth_event_id) + if sets is None: + # First time we're seeing this event, so we add it to the + # queue of things to fetch. + search.append((auth_event_depth, auth_event_id)) + + # Assume that this event is unreachable from any of the + # state sets until proven otherwise + sets = event_to_missing_sets[auth_event_id] = set( + range(len(state_sets)) + ) + else: + # We've previously seen this event, so look up its auth + # events and recursively mark all ancestors as reachable + # by the current event's state set. + a_ids = event_to_auth_events.get(auth_event_id) + while a_ids: + new_aids = set() + for a_id in a_ids: + event_to_missing_sets[a_id].intersection_update( + event_to_missing_sets[event_id] + ) + + b = event_to_auth_events.get(a_id) + if b: + new_aids.update(b) + + a_ids = new_aids + + # Mark that the auth event is reachable by the approriate sets. + sets.intersection_update(event_to_missing_sets[event_id]) + + search.sort() + + # Return all events where not all sets can reach them. + return {eid for eid, n in event_to_missing_sets.items() if n} + def get_oldest_events_in_room(self, room_id): return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 9988a6d3fc..8eed590929 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -608,6 +608,23 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end + @defer.inlineCallbacks + def get_time_of_last_push_action_before(self, stream_ordering): + def f(txn): + sql = ( + "SELECT e.received_ts" + " FROM event_push_actions AS ep" + " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" + " WHERE ep.stream_ordering > ?" + " ORDER BY ep.stream_ordering ASC" + " LIMIT 1" + ) + txn.execute(sql, (stream_ordering,)) + return txn.fetchone() + + result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + return result[0] if result else None + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" @@ -736,23 +753,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return push_actions @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): - def f(txn): - sql = ( - "SELECT e.received_ts" - " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ?" - " ORDER BY ep.stream_ordering ASC" - " LIMIT 1" - ) - txn.execute(sql, (stream_ordering,)) - return txn.fetchone() - - result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) - return result[0] if result else None - - @defer.inlineCallbacks def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 8ae23df00a..d593ef47b8 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py
@@ -1168,7 +1168,11 @@ class EventsStore( and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json(prune_event_dict(original_event.get_dict())) + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) else: # Redaction wasn't allowed pruned_json = None @@ -1929,7 +1933,9 @@ class EventsStore( return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json(prune_event_dict(event.get_dict())) + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) # Update the event_json table to replace the event's JSON with the pruned # JSON. diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 3704995bad..8d2bfeb4a3 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py
@@ -28,9 +28,12 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError -from synapse.api.room_versions import EventFormatVersions -from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) +from synapse.events import make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process @@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore): # of a event format version, so it must be a V1 event. format_version = EventFormatVersions.V1 - original_ev = event_type_from_format_version(format_version)( + room_version_id = row["room_version_id"] + + if not room_version_id: + # this should only happen for out-of-band membership events + if not internal_metadata.get("out_of_band_membership"): + logger.warning( + "Room %s for event %s is unknown", d["room_id"], event_id + ) + continue + + # take a wild stab at the room version based on the event format + if format_version == EventFormatVersions.V1: + room_version = RoomVersions.V1 + elif format_version == EventFormatVersions.V2: + room_version = RoomVersions.V3 + else: + room_version = RoomVersions.V5 + else: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + logger.error( + "Event %s in room %s has unknown room version %s", + event_id, + d["room_id"], + room_version_id, + ) + continue + + if room_version.event_format != format_version: + logger.error( + "Event %s in room %s with version %s has wrong format: " + "expected %s, was %s", + event_id, + d["room_id"], + room_version_id, + room_version.event_format, + format_version, + ) + continue + + original_ev = make_event_from_dict( event_dict=d, + room_version=room_version, internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) @@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore): of EventFormatVersions. 'None' means the event predates EventFormatVersions (so the event is format V1). + * room_version_id (str|None): The version of the room which contains the event. + Hopefully one of RoomVersions. + + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + * rejected_reason (str|None): if the event was rejected, the reason why. @@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore): """ event_dict = {} for evs in batch_iter(event_ids, 200): - sql = ( - "SELECT " - " e.event_id, " - " e.internal_metadata," - " e.json," - " e.format_version, " - " rej.reason " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " WHERE " - ) + sql = """\ + SELECT + e.event_id, + e.internal_metadata, + e.json, + e.format_version, + r.room_version, + rej.reason + FROM event_json as e + LEFT JOIN rooms r USING (room_id) + LEFT JOIN rejections as rej USING (event_id) + WHERE """ clause, args = make_in_list_sql_clause( txn.database_engine, "e.event_id", evs @@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore): "internal_metadata": row[1], "json": row[2], "format_version": row[3], - "rejected_reason": row[4], + "room_version_id": row[4], + "rejected_reason": row[5], "redactions": [], } diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 1507a14e09..925bc5691b 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): def _count_users(txn): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" - txn.execute(sql) (count,) = txn.fetchone() return count return self.db.runInteraction("count_users", _count_users) + @cached(num_args=0) + def get_monthly_active_count_by_service(self): + """Generates current count of monthly active users broken down by service. + A service is typically an appservice but also includes native matrix users. + Since the `monthly_active_users` table is populated from the `user_ips` table + `config.track_appservice_user_ips` must be set to `true` for this + method to return anything other than native matrix users. + + Returns: + Deferred[dict]: dict that includes a mapping between app_service_id + and the number of occurrences. + + """ + + def _count_users_by_service(txn): + sql = """ + SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + GROUP BY appservice_id; + """ + + txn.execute(sql) + result = txn.fetchall() + return dict(result) + + return self.db.runInteraction("count_users_by_service", _count_users_by_service) + @defer.inlineCallbacks def get_registered_reserved_users(self): """Of the reserved threepids defined in config, which are associated @@ -292,6 +319,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) self._invalidate_cache_and_stream( + txn, self.get_monthly_active_count_by_service, () + ) + self._invalidate_cache_and_stream( txn, self.user_last_seen_monthly_active, (user_id,) ) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 49306642ed..3e53c8568a 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py
@@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self.db.simple_update_one( - table="users", - keyvalues={"name": user.to_string()}, - updatevalues={"admin": 1 if admin else 0}, - desc="set_server_admin", - ) + + def set_server_admin_txn(txn): + self.db.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres new file mode 100644
index 0000000000..92aaadde0d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
@@ -0,0 +1,39 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- When we first added the room_version column to the rooms table, it was populated from +-- the current_state_events table. However, there was an issue causing a background +-- update to clean up the current_state_events table for rooms where the server is no +-- longer participating, before that column could be populated. Therefore, some rooms had +-- a NULL room_version. + +-- The rooms_version_column_2.sql.* delta files were introduced to make the populating +-- synchronous instead of running it in a background update, which fixed this issue. +-- However, all of the instances of Synapse installed or updated in the meantime got +-- their rooms table corrupted with NULL room_versions. + +-- This query fishes out the room versions from the create event using the state_events +-- table instead of the current_state_events one, as the former still have all of the +-- create events. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json::json->'content'->>'room_version','1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; + +-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using +-- sqlite syntax for the json extraction. diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite new file mode 100644
index 0000000000..e19dab97cb --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
@@ -0,0 +1,23 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- see rooms_version_column_3.sql.postgres for details of what's going on here. + +UPDATE rooms SET room_version=( + SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1') + FROM state_events se INNER JOIN event_json ej USING (event_id) + WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key='' + LIMIT 1 +) WHERE rooms.room_version IS NULL; diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 12c982cb26..725e12507f 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -15,6 +15,8 @@ import logging +from twisted.internet import defer + from synapse.storage._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -56,7 +58,7 @@ class StateDeltasStore(SQLBaseStore): # if the CSDs haven't changed between prev_stream_id and now, we # know for certain that they haven't changed between prev_stream_id and # max_stream_id. - return max_stream_id, [] + return defer.succeed((max_stream_id, [])) def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 1953614401..e61595336c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -15,9 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import sys import time -from typing import Iterable, Tuple +from time import monotonic as monotonic_time +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from six import iteritems, iterkeys, itervalues from six.moves import intern, range @@ -29,27 +29,21 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig -from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.logging.context import ( + LoggingContext, + LoggingContextOrSentinel, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Connection, Cursor from synapse.util.stringutils import exception_to_unicode -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time - logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 +# python 3 does not have a maximum int value +MAX_TXN_ID = 2 ** 63 - 1 sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") @@ -77,7 +71,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { def make_pool( - reactor, db_config: DatabaseConnectionConfig, engine + reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine ) -> adbapi.ConnectionPool: """Get the connection pool for the database. """ @@ -90,7 +84,9 @@ def make_pool( ) -def make_conn(db_config: DatabaseConnectionConfig, engine): +def make_conn( + db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine +) -> Connection: """Make a new connection to the database and return it. Returns: @@ -107,20 +103,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine): return db_conn -class LoggingTransaction(object): +# The type of entry which goes on our after_callbacks and exception_callbacks lists. +# +# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so +# that mypy sees the type but the runtime python doesn't. +_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] + + +class LoggingTransaction: """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method. Args: txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to + name: The name of this transactions for logging. + database_engine + after_callbacks: A list that callbacks will be appended to that have been added by `call_after` which should be run on successful completion of the transaction. None indicates that no callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended + exception_callbacks: A list that callbacks will be appended to that have been added by `call_on_exception` which should be run if transaction ends with an error. None indicates that no callbacks should be allowed to be scheduled to run. @@ -135,46 +138,67 @@ class LoggingTransaction(object): ] def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + self, + txn: Cursor, + name: str, + database_engine: BaseDatabaseEngine, + after_callbacks: Optional[List[_CallbackListEntry]] = None, + exception_callbacks: Optional[List[_CallbackListEntry]] = None, ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) + self.txn = txn + self.name = name + self.database_engine = database_engine + self.after_callbacks = after_callbacks + self.exception_callbacks = exception_callbacks - def call_after(self, callback, *args, **kwargs): + def call_after(self, callback: "Callable[..., None]", *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ + # if self.after_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) - def call_on_exception(self, callback, *args, **kwargs): + def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): + # if self.exception_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) - def __getattr__(self, name): - return getattr(self.txn, name) + def fetchall(self) -> List[Tuple]: + return self.txn.fetchall() - def __setattr__(self, name, value): - setattr(self.txn, name, value) + def fetchone(self) -> Tuple: + return self.txn.fetchone() - def __iter__(self): + def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() + @property + def rowcount(self) -> int: + return self.txn.rowcount + + @property + def description(self) -> Any: + return self.txn.description + def execute_batch(self, sql, args): if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch + from psycopg2.extras import execute_batch # type: ignore self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: for val in args: self.execute(sql, val) - def execute(self, sql, *args): + def execute(self, sql: str, *args: Any): self._do_execute(self.txn.execute, sql, *args) - def executemany(self, sql, *args): + def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) def _make_sql_one_line(self, sql): @@ -207,6 +231,9 @@ class LoggingTransaction(object): sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_query_timer.labels(sql.split()[0]).observe(secs) + def close(self): + self.txn.close() + class PerformanceCounters(object): def __init__(self): @@ -251,7 +278,9 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): + def __init__( + self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + ): self.hs = hs self._clock = hs.get_clock() self._database_config = database_config @@ -259,9 +288,9 @@ class Database(object): self.updates = BackgroundUpdater(hs, self) - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 + self._previous_txn_total_time = 0.0 + self._current_txn_total_time = 0.0 + self._previous_loop_ts = 0.0 # TODO(paul): These can eventually be removed once the metrics code # is running in mainline, and we have some nice monitoring frontends @@ -463,23 +492,23 @@ class Database(object): sql_txn_timer.labels(desc).observe(duration) @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): + def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): """Starts a transaction on the database and runs a given function Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a + desc: description of the transaction, for logging and metrics + func: callback function, which will be called with a database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func """ - after_callbacks = [] - exception_callbacks = [] + after_callbacks = [] # type: List[_CallbackListEntry] + exception_callbacks = [] # type: List[_CallbackListEntry] if LoggingContext.current_context() == LoggingContext.sentinel: logger.warning("Starting db txn '%s' from sentinel context", desc) @@ -505,20 +534,22 @@ class Database(object): return result @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): + def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any): """Wraps the .runWithConnection() method on the underlying db_pool. Arguments: - func (func): callback function, which will be called with a + func: callback function, which will be called with a database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func """ - parent_context = LoggingContext.current_context() + parent_context = ( + LoggingContext.current_context() + ) # type: Optional[LoggingContextOrSentinel] if parent_context == LoggingContext.sentinel: logger.warning( "Starting db connection from sentinel context: metrics will be lost" @@ -800,7 +831,7 @@ class Database(object): return False # We didn't find any existing rows, so insert a new one - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(values) allvalues.update(insertion_values) @@ -829,7 +860,7 @@ class Database(object): Returns: None """ - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(insertion_values) @@ -916,7 +947,7 @@ class Database(object): Returns: None """ - allnames = [] + allnames = [] # type: List[str] allnames.extend(key_names) allnames.extend(value_names) @@ -1100,7 +1131,7 @@ class Database(object): keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ - results = [] + results = [] # type: List[Dict[str, Any]] if not iterable: return results @@ -1439,7 +1470,7 @@ class Database(object): raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") where_clause = "WHERE " if filters or keyvalues else "" - arg_list = [] + arg_list = [] # type: List[Any] if filters: where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) arg_list += list(filters.values()) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import importlib import platform -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} - -def create_engine(database_config): +def create_engine(database_config) -> BaseDatabaseEngine: name = database_config["name"] - engine_class = SUPPORTED_MODULE.get(name, None) - if engine_class: + if name == "sqlite3": + import sqlite3 + + return Sqlite3Engine(sqlite3, database_config) + + if name == "psycopg2": # pypy requires psycopg2cffi rather than psycopg2 - if name == "psycopg2" and platform.python_implementation() == "PyPy": - name = "psycopg2cffi" - module = importlib.import_module(name) - return engine_class(module, database_config) + if platform.python_implementation() == "PyPy": + import psycopg2cffi as psycopg2 # type: ignore + else: + import psycopg2 # type: ignore + + return PostgresEngine(psycopg2, database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) -__all__ = ["create_engine", "IncorrectDatabaseSetup"] +__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"] diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc +from typing import Generic, TypeVar + +from synapse.storage.types import Connection class IncorrectDatabaseSetup(RuntimeError): pass + + +ConnectionType = TypeVar("ConnectionType", bound=Connection) + + +class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): + def __init__(self, module, database_config: dict): + self.module = module + + @property + @abc.abstractmethod + def single_threaded(self) -> bool: + ... + + @property + @abc.abstractmethod + def can_native_upsert(self) -> bool: + """ + Do we support native UPSERTs? + """ + ... + + @property + @abc.abstractmethod + def supports_tuple_comparison(self) -> bool: + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + ... + + @property + @abc.abstractmethod + def supports_using_any_list(self) -> bool: + """ + Do we support using `a = ANY(?)` and passing a list + """ + ... + + @abc.abstractmethod + def check_database( + self, db_conn: ConnectionType, allow_outdated_version: bool = False + ) -> None: + ... + + @abc.abstractmethod + def check_new_database(self, txn) -> None: + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + ... + + @abc.abstractmethod + def convert_param_style(self, sql: str) -> str: + ... + + @abc.abstractmethod + def on_new_connection(self, db_conn: ConnectionType) -> None: + ... + + @abc.abstractmethod + def is_deadlock(self, error: Exception) -> bool: + ... + + @abc.abstractmethod + def is_connection_closed(self, conn: ConnectionType) -> bool: + ... + + @abc.abstractmethod + def lock_table(self, txn, table: str) -> None: + ... + + @abc.abstractmethod + def get_next_state_group_id(self, txn) -> int: + """Returns an int that can be used as a new state_group ID + """ + ... + + @property + @abc.abstractmethod + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0' + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 53b3f372b0..6c7d08a6f2 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py
@@ -15,16 +15,14 @@ import logging -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup logger = logging.getLogger(__name__) -class PostgresEngine(object): - single_threaded = False - +class PostgresEngine(BaseDatabaseEngine): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) self.module.extensions.register_type(self.module.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do @@ -36,6 +34,10 @@ class PostgresEngine(object): self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet + @property + def single_threaded(self) -> bool: + return False + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 641e490697..3bc2e8b986 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py
@@ -12,16 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import struct import threading +import typing + +from synapse.storage.engines import BaseDatabaseEngine +if typing.TYPE_CHECKING: + import sqlite3 # noqa: F401 -class Sqlite3Engine(object): - single_threaded = True +class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in (None, ":memory:",) @@ -32,6 +35,10 @@ class Sqlite3Engine(object): self._current_state_group_id_lock = threading.Lock() @property + def single_threaded(self) -> bool: + return True + + @property def can_native_upsert(self): """ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some @@ -68,7 +75,6 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - # We need to import here to avoid an import loop. from synapse.storage.prepare_database import prepare_database diff --git a/synapse/storage/types.py b/synapse/storage/types.py new file mode 100644
index 0000000000..daff81c5ee --- /dev/null +++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterable, Iterator, List, Tuple + +from typing_extensions import Protocol + + +""" +Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 +""" + + +class Cursor(Protocol): + def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + ... + + def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + ... + + def fetchall(self) -> List[Tuple]: + ... + + def fetchone(self) -> Tuple: + ... + + @property + def description(self) -> Any: + return None + + @property + def rowcount(self) -> int: + return 0 + + def __iter__(self) -> Iterator[Tuple]: + ... + + def close(self) -> None: + ... + + +class Connection(Protocol): + def cursor(self) -> Cursor: + ... + + def close(self) -> None: + ... + + def commit(self) -> None: + ... + + def rollback(self, *args, **kwargs) -> None: + ...