diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index d55733a4cd..8af5f7de54 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -317,14 +317,16 @@ class DeviceWorkerStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
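+
+ # (Illustrative) after this merge, result["keys"]["signatures"] looks
+ # like {"@alice:example.org": {"ed25519:DEVICEID": "<signature>"}}.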
+
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
else:
result["deleted"] = True
@@ -525,14 +527,16 @@ class DeviceWorkerStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
results.append(result)
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 49a7b8b433..62d4e9f599 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,7 +14,7 @@
# limitations under the License.
import itertools
import logging
-from typing import List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
from six.moves.queue import Empty, PriorityQueue
@@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
+ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ """Given sets of state events figure out the auth chain difference (as
+ per state res v2 algorithm).
+
+ This is equivalent to fetching the full auth chain for each set of state
+ and returning the events that don't appear in each and every auth
+ chain.
+
+ Returns:
+ Deferred[Set[str]]
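+
+ Example (illustrative): with state_sets = [{"$a"}, {"$a", "$b"}],
+ where "$b"'s only auth event is "$a", every set's auth chain
+ contains "$a" but only the second contains "$b", so {"$b"} is
+ returned.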
+ """
+
+ return self.db.runInteraction(
+ "get_auth_chain_difference",
+ self._get_auth_chain_difference_txn,
+ state_sets,
+ )
+
+ def _get_auth_chain_difference_txn(
+ self, txn, state_sets: List[Set[str]]
+ ) -> Set[str]:
+
+ # Algorithm Description
+ # ~~~~~~~~~~~~~~~~~~~~~
+ #
+ # The idea here is to basically walk the auth graph of each state set in
+ # tandem, keeping track of which auth events are reachable by each state
+ # set. If we reach an auth event we've already visited (via a different
+ # state set) then we mark that auth event and all ancestors as reachable
+ # by the state set. This requires that we keep track of the auth chains
+ # in memory.
+ #
+ # Doing it in such a way means that we can stop early if all auth
+ # events we're currently walking are reachable by all state sets.
+ #
+ # *Note*: We can't stop walking an event's auth chain if it is reachable
+ # by all state sets. This is because other auth chains we're walking
+ # might be reachable only via the original auth chain. For example,
+ # given the following auth chain:
+ #
+ # A -> C -> D -> E
+ # / /
+ # B -´---------´
+ #
+ # and state sets {A} and {B} then walking the auth chains of A and B
+ # would immediately show that C is reachable by both. However, if we
+ # stopped at C then we'd only reach E via the auth chain of B and so E
+ # would erroneously get included in the returned difference.
+ #
+ # The other thing that we do is limit the number of auth chains we walk
+ # at once, due to practical limits (i.e. we can only query the database
+ # with a limited set of parameters). We pick the auth chains we walk
+ # each iteration based on their depth, in the hope that events with a
+ # lower depth are likely reachable by those with higher depths.
+ #
+ # We could use any ordering that we believe would give a rough
+ # topological ordering, e.g. origin server timestamp. If the ordering
+ # chosen is not topological then the algorithm still produces the
+ # correct result, just a bit less efficiently. This is why it is safe
+ # to use "depth" here.
+
+ initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+ # Dict from events in auth chains to which sets *cannot* reach them.
+ # I.e. if the set is empty then all sets can reach the event.
+ event_to_missing_sets = {
+ event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+ for event_id in initial_events
+ }
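+ # e.g. (illustrative) for state_sets = [{"$a"}, {"$a", "$b"}] this is
+ # initially {"$a": set(), "$b": {0}}: no set is missing "$a", while
+ # state set 0 cannot (yet) reach "$b".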
+
+ # We need to get the depth of the initial events for sorting purposes.
+ sql = """
+ SELECT depth, event_id FROM events
+ WHERE %s
+ ORDER BY depth ASC
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", initial_events
+ )
+ txn.execute(sql % (clause,), args)
+
+ # The sorted list of events whose auth chains we should walk.
+ search = txn.fetchall() # type: List[Tuple[int, str]]
+
+ # Map from event to its auth events
+ event_to_auth_events = {} # type: Dict[str, Set[str]]
+
+ base_sql = """
+ SELECT a.event_id, auth_id, depth
+ FROM event_auth AS a
+ INNER JOIN events AS e ON (e.event_id = a.auth_id)
+ WHERE
+ """
+
+ while search:
+ # Check whether all our current walks are reachable by all state
+ # sets. If so we can bail.
+ if all(not event_to_missing_sets[eid] for _, eid in search):
+ break
+
+ # Fetch the auth events, and their depths, of the last N events
+ # we're currently walking.
+ search, chunk = search[:-100], search[-100:]
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+ )
+ txn.execute(base_sql + clause, args)
+
+ for event_id, auth_event_id, auth_event_depth in txn:
+ event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+ sets = event_to_missing_sets.get(auth_event_id)
+ if sets is None:
+ # First time we're seeing this event, so we add it to the
+ # queue of things to fetch.
+ search.append((auth_event_depth, auth_event_id))
+
+ # Assume that this event is unreachable from any of the
+ # state sets until proven otherwise
+ sets = event_to_missing_sets[auth_event_id] = set(
+ range(len(state_sets))
+ )
+ else:
+ # We've previously seen this event, so look up its auth
+ # events and recursively mark all ancestors as reachable
+ # by the sets that can reach the current event.
+ a_ids = event_to_auth_events.get(auth_event_id)
+ while a_ids:
+ new_aids = set()
+ for a_id in a_ids:
+ event_to_missing_sets[a_id].intersection_update(
+ event_to_missing_sets[event_id]
+ )
+
+ b = event_to_auth_events.get(a_id)
+ if b:
+ new_aids.update(b)
+
+ a_ids = new_aids
+
+ # Mark that the auth event is reachable by the appropriate sets.
+ sets.intersection_update(event_to_missing_sets[event_id])
+
+ search.sort()
+
+ # Return all events where not all sets can reach them.
+ return {eid for eid, n in event_to_missing_sets.items() if n}
+
def get_oldest_events_in_room(self, room_id):
return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 9988a6d3fc..8eed590929 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -608,6 +608,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
+ @defer.inlineCallbacks
+ def get_time_of_last_push_action_before(self, stream_ordering):
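+ """Returns the `received_ts` of the first push action whose stream
+ ordering is strictly greater than `stream_ordering`, or None if no
+ such action exists.
+ """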
+ def f(txn):
+ sql = (
+ "SELECT e.received_ts"
+ " FROM event_push_actions AS ep"
+ " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
+ " WHERE ep.stream_ordering > ?"
+ " ORDER BY ep.stream_ordering ASC"
+ " LIMIT 1"
+ )
+ txn.execute(sql, (stream_ordering,))
+ return txn.fetchone()
+
+ result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+ return result[0] if result else None
+
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -736,23 +753,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return push_actions
@defer.inlineCallbacks
- def get_time_of_last_push_action_before(self, stream_ordering):
- def f(txn):
- sql = (
- "SELECT e.received_ts"
- " FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ?"
- " ORDER BY ep.stream_ordering ASC"
- " LIMIT 1"
- )
- txn.execute(sql, (stream_ordering,))
- return txn.fetchone()
-
- result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
- return result[0] if result else None
-
- @defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 8ae23df00a..d593ef47b8 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1168,7 +1168,11 @@ class EventsStore(
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = encode_json(prune_event_dict(original_event.get_dict()))
+ pruned_json = encode_json(
+ prune_event_dict(
+ original_event.room_version, original_event.get_dict()
+ )
+ )
else:
# Redaction wasn't allowed
pruned_json = None
@@ -1929,7 +1933,9 @@ class EventsStore(
return
# Prune the event's dict then convert it to JSON.
- pruned_json = encode_json(prune_event_dict(event.get_dict()))
+ pruned_json = encode_json(
+ prune_event_dict(event.room_version, event.get_dict())
+ )
# Update the event_json table to replace the event's JSON with the pruned
# JSON.
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 3704995bad..8d2bfeb4a3 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -28,9 +28,12 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
-from synapse.api.room_versions import EventFormatVersions
-from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
-from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.api.room_versions import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ RoomVersions,
+)
+from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -580,8 +583,49 @@ class EventsWorkerStore(SQLBaseStore):
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
- original_ev = event_type_from_format_version(format_version)(
+ room_version_id = row["room_version_id"]
+
+ if not room_version_id:
+ # this should only happen for out-of-band membership events
+ if not internal_metadata.get("out_of_band_membership"):
+ logger.warning(
+ "Room %s for event %s is unknown", d["room_id"], event_id
+ )
+ continue
+
+ # take a wild stab at the room version based on the event format
+ if format_version == EventFormatVersions.V1:
+ room_version = RoomVersions.V1
+ elif format_version == EventFormatVersions.V2:
+ room_version = RoomVersions.V3
+ else:
+ room_version = RoomVersions.V5
+ else:
+ room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version:
+ logger.error(
+ "Event %s in room %s has unknown room version %s",
+ event_id,
+ d["room_id"],
+ room_version_id,
+ )
+ continue
+
+ if room_version.event_format != format_version:
+ logger.error(
+ "Event %s in room %s with version %s has wrong format: "
+ "expected %s, was %s",
+ event_id,
+ d["room_id"],
+ room_version_id,
+ room_version.event_format,
+ format_version,
+ )
+ continue
+
+ original_ev = make_event_from_dict(
event_dict=d,
+ room_version=room_version,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
@@ -661,6 +705,12 @@ class EventsWorkerStore(SQLBaseStore):
of EventFormatVersions. 'None' means the event predates
EventFormatVersions (so the event is format V1).
+ * room_version_id (str|None): The version of the room which contains the event.
+ Hopefully one of RoomVersions.
+
+ Due to historical reasons, there may be a few events in the database which
+ do not have an associated room; in this case None will be returned here.
+
* rejected_reason (str|None): if the event was rejected, the reason
why.
@@ -676,17 +726,18 @@ class EventsWorkerStore(SQLBaseStore):
"""
event_dict = {}
for evs in batch_iter(event_ids, 200):
- sql = (
- "SELECT "
- " e.event_id, "
- " e.internal_metadata,"
- " e.json,"
- " e.format_version, "
- " rej.reason "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " WHERE "
- )
+ sql = """\
+ SELECT
+ e.event_id,
+ e.internal_metadata,
+ e.json,
+ e.format_version,
+ r.room_version,
+ rej.reason
+ FROM event_json as e
+ LEFT JOIN rooms r USING (room_id)
+ LEFT JOIN rejections as rej USING (event_id)
+ WHERE """
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", evs
@@ -701,7 +752,8 @@ class EventsWorkerStore(SQLBaseStore):
"internal_metadata": row[1],
"json": row[2],
"format_version": row[3],
- "rejected_reason": row[4],
+ "room_version_id": row[4],
+ "rejected_reason": row[5],
"redactions": [],
}
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 1507a14e09..925bc5691b 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -43,13 +43,40 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def _count_users(txn):
sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
-
txn.execute(sql)
(count,) = txn.fetchone()
return count
return self.db.runInteraction("count_users", _count_users)
+ @cached(num_args=0)
+ def get_monthly_active_count_by_service(self):
+ """Generates current count of monthly active users broken down by service.
+ A service is typically an appservice, but native matrix users are
+ also counted (grouped under "native"). Since the
+ `monthly_active_users` table is populated from the `user_ips` table,
+ `config.track_appservice_user_ips` must be set to `true` for this
+ method to return anything other than native matrix users.
+
+ Returns:
+ Deferred[dict]: dict that maps from app_service_id to the number
+ of monthly active users for that service.
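+ e.g. {"native": 23, "some-appservice-id": 5} (illustrative values).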
+
+ """
+
+ def _count_users_by_service(txn):
+ sql = """
+ SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+ FROM monthly_active_users
+ LEFT JOIN users ON monthly_active_users.user_id=users.name
+ GROUP BY appservice_id;
+ """
+
+ txn.execute(sql)
+ result = txn.fetchall()
+ return dict(result)
+
+ return self.db.runInteraction("count_users_by_service", _count_users_by_service)
+
@defer.inlineCallbacks
def get_registered_reserved_users(self):
"""Of the reserved threepids defined in config, which are associated
@@ -292,6 +319,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
self._invalidate_cache_and_stream(
+ txn, self.get_monthly_active_count_by_service, ()
+ )
+ self._invalidate_cache_and_stream(
txn, self.user_last_seen_monthly_active, (user_id,)
)
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 49306642ed..3e53c8568a 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -301,12 +301,16 @@ class RegistrationWorkerStore(SQLBaseStore):
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""
- return self.db.simple_update_one(
- table="users",
- keyvalues={"name": user.to_string()},
- updatevalues={"admin": 1 if admin else 0},
- desc="set_server_admin",
- )
+
+ def set_server_admin_txn(txn):
+ self.db.simple_update_one_txn(
+ txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
+ )
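+ # Invalidate the get_user_by_id cache entry so the updated admin
+ # flag takes effect immediately.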
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user.to_string(),)
+ )
+
+ return self.db.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
new file mode 100644
index 0000000000..92aaadde0d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
@@ -0,0 +1,39 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- When we first added the room_version column to the rooms table, it was populated from
+-- the current_state_events table. However, there was an issue causing a background
+-- update to clean up the current_state_events table for rooms where the server is no
+-- longer participating, before that column could be populated. Therefore, some rooms had
+-- a NULL room_version.
+
+-- The rooms_version_column_2.sql.* delta files were introduced to make the populating
+-- synchronous instead of running it in a background update, which fixed this issue.
+-- However, all of the instances of Synapse installed or updated in the meantime got
+-- their rooms table corrupted with NULL room_versions.
+
+-- This query fishes out the room versions from the create event using the state_events
+-- table instead of the current_state_events one, as the former still has all of the
+-- create events.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json::json->'content'->>'room_version','1')
+ FROM state_events se INNER JOIN event_json ej USING (event_id)
+ WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+ LIMIT 1
+) WHERE rooms.room_version IS NULL;
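+
+-- (Illustrative) a create event whose content includes {"room_version": "5"}
+-- sets room_version to '5'; create events which predate room versions have
+-- no such key, so COALESCE falls back to '1'.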
+
+-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
new file mode 100644
index 0000000000..e19dab97cb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
@@ -0,0 +1,23 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_3.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+ FROM state_events se INNER JOIN event_json ej USING (event_id)
+ WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+ LIMIT 1
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 12c982cb26..725e12507f 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -15,6 +15,8 @@
import logging
+from twisted.internet import defer
+
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -56,7 +58,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
- return max_stream_id, []
+ return defer.succeed((max_stream_id, []))
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 1953614401..e61595336c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -15,9 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import sys
import time
-from typing import Iterable, Tuple
+from time import monotonic as monotonic_time
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@@ -29,27 +29,21 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.logging.context import (
+ LoggingContext,
+ LoggingContextOrSentinel,
+ make_deferred_yieldable,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
from synapse.util.stringutils import exception_to_unicode
-# import a function which will return a monotonic time, in seconds
-try:
- # on python 3, use time.monotonic, since time.clock can go backwards
- from time import monotonic as monotonic_time
-except ImportError:
- # ... but python 2 doesn't have it
- from time import clock as monotonic_time
-
logger = logging.getLogger(__name__)
-try:
- MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
- # python 3 does not have a maximum int value
- MAX_TXN_ID = 2 ** 63 - 1
+# python 3 does not have a maximum int value
+MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -77,7 +71,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
- reactor, db_config: DatabaseConnectionConfig, engine
+ reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
@@ -90,7 +84,9 @@ def make_pool(
)
-def make_conn(db_config: DatabaseConnectionConfig, engine):
+def make_conn(
+ db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
"""Make a new connection to the database and return it.
Returns:
@@ -107,20 +103,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
return db_conn
-class LoggingTransaction(object):
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
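+#
+# Each entry is a (callback, args, kwargs) triple, e.g. (illustrative):
+# (cache.invalidate, (key,), {}).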
+
+
+class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method.
Args:
txn: The database transcation object to wrap.
- name (str): The name of this transactions for logging.
- database_engine (Sqlite3Engine|PostgresEngine)
- after_callbacks(list|None): A list that callbacks will be appended to
+ name: The name of this transaction, for logging.
+ database_engine
+ after_callbacks: A list that callbacks will be appended to
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
- exception_callbacks(list|None): A list that callbacks will be appended
+ exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
should be allowed to be scheduled to run.
@@ -135,46 +138,67 @@ class LoggingTransaction(object):
]
def __init__(
- self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+ self,
+ txn: Cursor,
+ name: str,
+ database_engine: BaseDatabaseEngine,
+ after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
- object.__setattr__(self, "txn", txn)
- object.__setattr__(self, "name", name)
- object.__setattr__(self, "database_engine", database_engine)
- object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "exception_callbacks", exception_callbacks)
+ self.txn = txn
+ self.name = name
+ self.database_engine = database_engine
+ self.after_callbacks = after_callbacks
+ self.exception_callbacks = exception_callbacks
- def call_after(self, callback, *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
+ # if self.after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback, *args, **kwargs):
+ def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ # if self.exception_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
- def __getattr__(self, name):
- return getattr(self.txn, name)
+ def fetchall(self) -> List[Tuple]:
+ return self.txn.fetchall()
- def __setattr__(self, name, value):
- setattr(self.txn, name, value)
+ def fetchone(self) -> Tuple:
+ return self.txn.fetchone()
- def __iter__(self):
+ def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
+ @property
+ def rowcount(self) -> int:
+ return self.txn.rowcount
+
+ @property
+ def description(self) -> Any:
+ return self.txn.description
+
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch
+ from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
self.execute(sql, val)
- def execute(self, sql, *args):
+ def execute(self, sql: str, *args: Any):
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql, *args):
+ def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql):
@@ -207,6 +231,9 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
+ def close(self):
+ self.txn.close()
+
class PerformanceCounters(object):
def __init__(self):
@@ -251,7 +278,9 @@ class Database(object):
_TXN_ID = 0
- def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
+ def __init__(
+ self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ ):
self.hs = hs
self._clock = hs.get_clock()
self._database_config = database_config
@@ -259,9 +288,9 @@ class Database(object):
self.updates = BackgroundUpdater(hs, self)
- self._previous_txn_total_time = 0
- self._current_txn_total_time = 0
- self._previous_loop_ts = 0
+ self._previous_txn_total_time = 0.0
+ self._current_txn_total_time = 0.0
+ self._previous_loop_ts = 0.0
# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +492,23 @@ class Database(object):
sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
- def runInteraction(self, desc, func, *args, **kwargs):
+ def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
"""Starts a transaction on the database and runs a given function
Arguments:
- desc (str): description of the transaction, for logging and metrics
- func (func): callback function, which will be called with a
+ desc: description of the transaction, for logging and metrics
+ func: callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
"""
- after_callbacks = []
- exception_callbacks = []
+ after_callbacks = [] # type: List[_CallbackListEntry]
+ exception_callbacks = [] # type: List[_CallbackListEntry]
if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -505,20 +534,22 @@ class Database(object):
return result
@defer.inlineCallbacks
- def runWithConnection(self, func, *args, **kwargs):
+ def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
- func (func): callback function, which will be called with a
+ func: callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
"""
- parent_context = LoggingContext.current_context()
+ parent_context = (
+ LoggingContext.current_context()
+ ) # type: Optional[LoggingContextOrSentinel]
if parent_context == LoggingContext.sentinel:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
@@ -800,7 +831,7 @@ class Database(object):
return False
# We didn't find any existing rows, so insert a new one
- allvalues = {}
+ allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
@@ -829,7 +860,7 @@ class Database(object):
Returns:
None
"""
- allvalues = {}
+ allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
allvalues.update(insertion_values)
@@ -916,7 +947,7 @@ class Database(object):
Returns:
None
"""
- allnames = []
+ allnames = [] # type: List[str]
allnames.extend(key_names)
allnames.extend(value_names)
@@ -1100,7 +1131,7 @@ class Database(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
- results = []
+ results = [] # type: List[Dict[str, Any]]
if not iterable:
return results
@@ -1439,7 +1470,7 @@ class Database(object):
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues else ""
- arg_list = []
+ arg_list = [] # type: List[Any]
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import importlib
import platform
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
- engine_class = SUPPORTED_MODULE.get(name, None)
- if engine_class:
+ if name == "sqlite3":
+ import sqlite3
+
+ return Sqlite3Engine(sqlite3, database_config)
+
+ if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
- if name == "psycopg2" and platform.python_implementation() == "PyPy":
- name = "psycopg2cffi"
- module = importlib.import_module(name)
- return engine_class(module, database_config)
+ if platform.python_implementation() == "PyPy":
+ import psycopg2cffi as psycopg2 # type: ignore
+ else:
+ import psycopg2 # type: ignore
+
+ return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError):
pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+ def __init__(self, module, database_config: dict):
+ self.module = module
+
+ @property
+ @abc.abstractmethod
+ def single_threaded(self) -> bool:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def can_native_upsert(self) -> bool:
+ """
+ Do we support native UPSERTs?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_tuple_comparison(self) -> bool:
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_using_any_list(self) -> bool:
+ """
+ Do we support using `a = ANY(?)` and passing a list
+ """
+ ...
+
+ @abc.abstractmethod
+ def check_database(
+ self, db_conn: ConnectionType, allow_outdated_version: bool = False
+ ) -> None:
+ ...
+
+ @abc.abstractmethod
+ def check_new_database(self, txn) -> None:
+ """Gets called when setting up a brand new database. This allows us to
+ apply stricter checks on new databases versus existing databases.
+ """
+ ...
+
+ @abc.abstractmethod
+ def convert_param_style(self, sql: str) -> str:
+ ...
+
+ @abc.abstractmethod
+ def on_new_connection(self, db_conn: ConnectionType) -> None:
+ ...
+
+ @abc.abstractmethod
+ def is_deadlock(self, error: Exception) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def is_connection_closed(self, conn: ConnectionType) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def lock_table(self, txn, table: str) -> None:
+ ...
+
+ @abc.abstractmethod
+ def get_next_state_group_id(self, txn) -> int:
+ """Returns an int that can be used as a new state_group ID
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def server_version(self) -> str:
+ """Gets a string giving the server version. For example: '3.22.0'
+ """
+ ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 53b3f372b0..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,16 +15,14 @@
import logging
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
logger = logging.getLogger(__name__)
-class PostgresEngine(object):
- single_threaded = False
-
+class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@@ -36,6 +34,10 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
+ @property
+ def single_threaded(self) -> bool:
+ return False
+
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 641e490697..3bc2e8b986 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,16 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import struct
import threading
+import typing
+
+from synapse.storage.engines import BaseDatabaseEngine
+if typing.TYPE_CHECKING:
+ import sqlite3 # noqa: F401
-class Sqlite3Engine(object):
- single_threaded = True
+class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",)
@@ -32,6 +35,10 @@ class Sqlite3Engine(object):
self._current_state_group_id_lock = threading.Lock()
@property
+ def single_threaded(self) -> bool:
+ return True
+
+ @property
def can_native_upsert(self):
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -68,7 +75,6 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
-
# We need to import here to avoid an import loop.
from synapse.storage.prepare_database import prepare_database
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
new file mode 100644
index 0000000000..daff81c5ee
--- /dev/null
+++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Iterable, Iterator, List, Tuple
+
+from typing_extensions import Protocol
+
+
+"""
+Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
+"""
+
+
+class Cursor(Protocol):
+ def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ ...
+
+ def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ ...
+
+ def fetchall(self) -> List[Tuple]:
+ ...
+
+ def fetchone(self) -> Tuple:
+ ...
+
+ @property
+ def description(self) -> Any:
+ return None
+
+ @property
+ def rowcount(self) -> int:
+ return 0
+
+ def __iter__(self) -> Iterator[Tuple]:
+ ...
+
+ def close(self) -> None:
+ ...
+
+
+class Connection(Protocol):
+ def cursor(self) -> Cursor:
+ ...
+
+ def close(self) -> None:
+ ...
+
+ def commit(self) -> None:
+ ...
+
+ def rollback(self, *args, **kwargs) -> None:
+ ...
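+
+
+# A minimal usage sketch (illustrative): any DB-API2 connection satisfies
+# `Connection` structurally, so for example:
+#
+#     def first_row(conn: Connection) -> Tuple:
+#         cur = conn.cursor()
+#         cur.execute("SELECT 1")
+#         return cur.fetchone()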