diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index ec89f645d4..8e5d78f6f7 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -17,18 +17,19 @@
"""
The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple
-databases). The `Database` class represents a single physical database. The
-`data_stores` are classes that talk directly to a `Database` instance and have
-associated schemas, background updates, etc. On top of those there are classes
-that provide high level interfaces that combine calls to multiple `data_stores`.
+databases). The `DatabasePool` class represents connections to a single physical
+database. The `databases` are classes that talk directly to a `DatabasePool`
+instance and have associated schemas, background updates, etc. On top of those
+there are classes that provide high level interfaces that combine calls to
+multiple `databases`.
There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
-from synapse.storage.data_stores import DataStores
-from synapse.storage.data_stores.main import DataStore
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
@@ -36,11 +37,11 @@ from synapse.storage.state import StateGroupStorage
__all__ = ["DataStores", "DataStore"]
-class Storage(object):
+class Storage:
"""The high level interfaces for talking to various storage layers.
"""
- def __init__(self, hs, stores: DataStores):
+ def __init__(self, hs, stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
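For orientation, a minimal usage sketch (not part of the patch; `hs` stands in for an already-configured HomeServer) of how the renamed layers fit together: `Databases` builds one `DatabasePool` per physical database plus the stores hosted on it, and the high-level `Storage` class combines them.

    from synapse.storage import Storage
    from synapse.storage.databases import Databases
    from synapse.storage.databases.main import DataStore

    def build_storage(hs):
        # `Databases` opens a `DatabasePool` for each configured physical
        # database and instantiates the stores hosted on it ("main", "state").
        stores = Databases(DataStore, hs)
        # The high-level `Storage` object combines calls across those stores.
        return Storage(hs, stores)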
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index bfce541ca7..ab49d227de 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,12 +19,11 @@ import random
from abc import ABCMeta
from typing import Any, Optional
-from canonicaljson import json
-
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -37,11 +36,11 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
- self.db = database
+ self.db_pool = database
self.rand = random.SystemRandom()
def process_replication_rows(self, stream_name, instance_name, token, rows):
@@ -58,7 +57,6 @@ class SQLBaseStore(metaclass=ABCMeta):
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
- self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
@@ -100,13 +98,13 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # Decode it to a Unicode string before feeding it to json.loads, so we
- # consistenty get a Unicode-containing object out.
+ # Decode it to a Unicode string before feeding it to the JSON decoder, since
+ # Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
try:
- return json.loads(db_content)
+ return json_decoder.decode(db_content)
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
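As a hedged illustration of the behaviour documented above, `db_to_json` accepts whatever the database driver hands back and always decodes to text before parsing:

    from synapse.storage._base import db_to_json

    # str, bytes, bytearray and memoryview inputs all yield the same object.
    assert db_to_json('{"a": 1}') == {"a": 1}
    assert db_to_json(b'{"a": 1}') == {"a": 1}
    assert db_to_json(memoryview(b'{"a": 1}')) == {"a": 1}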
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 59f3394b0a..810721ebe9 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -16,18 +16,15 @@
import logging
from typing import Optional
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import json_encoder
from . import engines
logger = logging.getLogger(__name__)
-class BackgroundUpdatePerformance(object):
+class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
def __init__(self, name):
@@ -74,7 +71,7 @@ class BackgroundUpdatePerformance(object):
return float(self.total_item_count) / float(self.total_duration_ms)
-class BackgroundUpdater(object):
+class BackgroundUpdater:
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
@@ -88,7 +85,7 @@ class BackgroundUpdater(object):
def __init__(self, hs, database):
self._clock = hs.get_clock()
- self.db = database
+ self.db_pool = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
@@ -139,7 +136,7 @@ class BackgroundUpdater(object):
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = await self.db.simple_select_onecol(
+ updates = await self.db_pool.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
@@ -160,7 +157,7 @@ class BackgroundUpdater(object):
if update_name == self._current_background_update:
return False
- update_exists = await self.db.simple_select_one_onecol(
+ update_exists = await self.db_pool.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="1",
@@ -189,10 +186,10 @@ class BackgroundUpdater(object):
ORDER BY ordering, update_name
"""
)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
if not self._current_background_update:
- all_pending_updates = await self.db.runInteraction(
+ all_pending_updates = await self.db_pool.runInteraction(
"background_updates", get_background_updates_txn,
)
if not all_pending_updates:
@@ -243,13 +240,16 @@ class BackgroundUpdater(object):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
- progress_json = await self.db.simple_select_one_onecol(
+ progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
)
- progress = json.loads(progress_json)
+ # Avoid a circular import.
+ from synapse.storage._base import db_to_json
+
+ progress = db_to_json(progress_json)
time_start = self._clock.time_msec()
items_updated = await update_handler(progress, batch_size)
@@ -305,9 +305,8 @@ class BackgroundUpdater(object):
update_name (str): Name of update
"""
- @defer.inlineCallbacks
- def noop_update(progress, batch_size):
- yield self._end_background_update(update_name)
+ async def noop_update(progress, batch_size):
+ await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, noop_update)
@@ -399,30 +398,30 @@ class BackgroundUpdater(object):
logger.debug("[SQL] %s", sql)
c.execute(sql)
- if isinstance(self.db.engine, engines.PostgresEngine):
+ if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql
elif psql_only:
runner = None
else:
runner = create_index_sqlite
- @defer.inlineCallbacks
- def updater(progress, batch_size):
+ async def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
- yield self.db.runWithConnection(runner)
- yield self._end_background_update(update_name)
+ await self.db_pool.runWithConnection(runner)
+ await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)
- def _end_background_update(self, update_name):
+ async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
Args:
- update_name(str): The name of the completed task to remove
+            update_name: The name of the completed task to remove
+
Returns:
- A deferred that completes once the task is removed.
+            None, once the task has been removed.
"""
if update_name != self._current_background_update:
raise Exception(
@@ -430,11 +429,11 @@ class BackgroundUpdater(object):
% update_name
)
self._current_background_update = None
- return self.db.simple_delete_one(
+ await self.db_pool.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
- def _background_update_progress(self, update_name: str, progress: dict):
+ async def _background_update_progress(self, update_name: str, progress: dict):
"""Update the progress of a background update
Args:
@@ -442,7 +441,7 @@ class BackgroundUpdater(object):
progress: The progress of the update.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
@@ -458,9 +457,9 @@ class BackgroundUpdater(object):
progress(dict): The progress of the update.
"""
- progress_json = json.dumps(progress)
+ progress_json = json_encoder.encode(progress)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
deleted file mode 100644
index 599ee470d4..0000000000
--- a/synapse/storage/data_stores/__init__.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.storage.data_stores.main.events import PersistEventsStore
-from synapse.storage.data_stores.state import StateGroupDataStore
-from synapse.storage.database import Database, make_conn
-from synapse.storage.engines import create_engine
-from synapse.storage.prepare_database import prepare_database
-
-logger = logging.getLogger(__name__)
-
-
-class DataStores(object):
- """The various data stores.
-
- These are low level interfaces to physical databases.
-
- Attributes:
- main (DataStore)
- """
-
- def __init__(self, main_store_class, hs):
- # Note we pass in the main store class here as workers use a different main
- # store.
-
- self.databases = []
- self.main = None
- self.state = None
- self.persist_events = None
-
- for database_config in hs.config.database.databases:
- db_name = database_config.name
- engine = create_engine(database_config.config)
-
- with make_conn(database_config, engine) as db_conn:
- logger.info("Preparing database %r...", db_name)
-
- engine.check_database(db_conn)
- prepare_database(
- db_conn, engine, hs.config, data_stores=database_config.data_stores,
- )
-
- database = Database(hs, database_config, engine)
-
- if "main" in database_config.data_stores:
- logger.info("Starting 'main' data store")
-
- # Sanity check we don't try and configure the main store on
- # multiple databases.
- if self.main:
- raise Exception("'main' data store already configured")
-
- self.main = main_store_class(database, db_conn, hs)
-
- # If we're on a process that can persist events also
- # instantiate a `PersistEventsStore`
- if hs.config.worker.writers.events == hs.get_instance_name():
- self.persist_events = PersistEventsStore(
- hs, database, self.main
- )
-
- if "state" in database_config.data_stores:
- logger.info("Starting 'state' data store")
-
- # Sanity check we don't try and configure the state store on
- # multiple databases.
- if self.state:
- raise Exception("'state' data store already configured")
-
- self.state = StateGroupDataStore(database, db_conn, hs)
-
- db_conn.commit()
-
- self.databases.append(database)
-
- logger.info("Database %r prepared", db_name)
-
- # Sanity check that we have actually configured all the required stores.
- if not self.main:
- raise Exception("No 'main' data store configured")
-
- if not self.state:
- raise Exception("No 'main' data store configured")
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b112ff3df2..ed8a9bffb1 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import time
+from sys import intern
from time import monotonic as monotonic_time
from typing import (
Any,
@@ -27,15 +28,14 @@ from typing import (
Optional,
Tuple,
TypeVar,
+ cast,
+ overload,
)
-from six import iteritems, iterkeys, itervalues
-from six.moves import intern, range
-
from prometheus_client import Histogram
+from typing_extensions import Literal
from twisted.enterprise import adbapi
-from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -51,11 +51,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
-logger = logging.getLogger(__name__)
-
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
+logger = logging.getLogger(__name__)
+
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -127,7 +127,7 @@ class LoggingTransaction:
method.
Args:
- txn: The database transcation object to wrap.
+ txn: The database transaction object to wrap.
name: The name of this transactions for logging.
database_engine
after_callbacks: A list that callbacks will be appended to
@@ -162,7 +162,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -173,7 +173,9 @@ class LoggingTransaction:
assert self.after_callbacks is not None
self.after_callbacks.append((callback, args, kwargs))
- def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ def call_on_exception(
+ self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+ ):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
@@ -197,7 +199,7 @@ class LoggingTransaction:
def description(self) -> Any:
return self.txn.description
- def execute_batch(self, sql, args):
+ def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
@@ -206,17 +208,17 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
- def execute(self, sql: str, *args: Any):
+ def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql: str, *args: Any):
+ def executemany(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql, *args):
+ def _do_execute(self, func, sql: str, *args: Any) -> None:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -235,31 +237,31 @@ class LoggingTransaction:
try:
return func(sql, *args)
except Exception as e:
- logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
secs = time.time() - start
sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
sql_query_timer.labels(sql.split()[0]).observe(secs)
- def close(self):
+ def close(self) -> None:
self.txn.close()
-class PerformanceCounters(object):
+class PerformanceCounters:
def __init__(self):
self.current_counters = {}
self.previous_counters = {}
- def update(self, key, duration_secs):
+ def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
- def interval(self, interval_duration_secs, limit=3):
+ def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
counters = []
- for name, (count, cum_time) in iteritems(self.current_counters):
+ for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(
(
@@ -281,7 +283,10 @@ class PerformanceCounters(object):
return top_n_counters
-class Database(object):
+R = TypeVar("R")
+
+
+class DatabasePool:
"""Wraps a single physical database and connection pool.
A single database may be used by multiple data stores.
@@ -329,13 +334,12 @@ class Database(object):
self._check_safe_to_upsert,
)
- def is_running(self):
+ def is_running(self) -> bool:
"""Is the database pool currently running
"""
return self._db_pool.running
- @defer.inlineCallbacks
- def _check_safe_to_upsert(self):
+ async def _check_safe_to_upsert(self) -> None:
"""
Is it safe to use native UPSERT?
@@ -344,7 +348,7 @@ class Database(object):
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = yield self.simple_select_list(
+ updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -366,7 +370,7 @@ class Database(object):
self._check_safe_to_upsert,
)
- def start_profiling(self):
+ def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
@@ -390,8 +394,15 @@ class Database(object):
self._clock.looping_call(loop, 10000)
def new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
+ self,
+ conn: Connection,
+ desc: str,
+ after_callbacks: List[_CallbackListEntry],
+ exception_callbacks: List[_CallbackListEntry],
+ func: "Callable[..., R]",
+ *args: Any,
+ **kwargs: Any
+ ) -> R:
start = monotonic_time()
txn_id = self._TXN_ID
@@ -421,7 +432,7 @@ class Database(object):
except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
- logger.warning(
+ transaction_logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
@@ -429,18 +440,20 @@ class Database(object):
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning("[TXN EROLL] {%s} %s", name, e1)
+ transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
if self.engine.is_deadlock(e):
- logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ transaction_logger.warning(
+ "[TXN DEADLOCK] {%s} %d/%d", name, i, N
+ )
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning(
+ transaction_logger.warning(
"[TXN EROLL] {%s} %s", name, e1,
)
continue
@@ -480,7 +493,7 @@ class Database(object):
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
cursor.close()
except Exception as e:
- logger.debug("[TXN FAIL] {%s} %s", name, e)
+ transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = monotonic_time()
@@ -494,8 +507,9 @@ class Database(object):
self._txn_perf_counters.update(desc, duration)
sql_txn_timer.labels(desc).observe(duration)
- @defer.inlineCallbacks
- def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+ async def runInteraction(
+ self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Starts a transaction on the database and runs a given function
Arguments:
@@ -508,7 +522,7 @@ class Database(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
@@ -517,7 +531,7 @@ class Database(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
- result = yield self.runWithConnection(
+ result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
@@ -534,10 +548,11 @@ class Database(object):
after_callback(*after_args, **after_kwargs)
raise
- return result
+ return cast(R, result)
- @defer.inlineCallbacks
- def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+ async def runWithConnection(
+ self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ ) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@@ -548,7 +563,7 @@ class Database(object):
kwargs: named args to pass to `func`
Returns:
- Deferred: The result of func
+ The result of func
"""
parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
if not parent_context:
@@ -571,18 +586,16 @@ class Database(object):
return func(conn, *args, **kwargs)
- result = yield make_deferred_yieldable(
+ return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
- return result
-
@staticmethod
- def cursor_to_dict(cursor):
+ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
"""Converts a SQL cursor into an list of dicts.
Args:
- cursor : The DBAPI cursor which has executed a query.
+ cursor: The DBAPI cursor which has executed a query.
Returns:
A list of dicts where the key is the column header.
"""
@@ -590,10 +603,29 @@ class Database(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
- def execute(self, desc, decoder, query, *args):
+ @overload
+ async def execute(
+ self, desc: str, decoder: Literal[None], query: str, *args: Any
+ ) -> List[Tuple[Any, ...]]:
+ ...
+
+ @overload
+ async def execute(
+ self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+ ) -> R:
+ ...
+
+ async def execute(
+ self,
+ desc: str,
+ decoder: Optional[Callable[[Cursor], R]],
+ query: str,
+ *args: Any
+ ) -> R:
"""Runs a single query for a result set.
Args:
+ desc: description of the transaction, for logging and metrics
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute
@@ -609,29 +641,33 @@ class Database(object):
else:
return txn.fetchall()
- return self.runInteraction(desc, interaction)
+ return await self.runInteraction(desc, interaction)
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- @defer.inlineCallbacks
- def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ async def simple_insert(
+ self,
+ table: str,
+ values: Dict[str, Any],
+ or_ignore: bool = False,
+ desc: str = "simple_insert",
+ ) -> bool:
"""Executes an INSERT query on the named table.
Args:
- table : string giving the table name
- values : dict of new column names and values for them
- or_ignore : bool stating whether an exception should be raised
+ table: string giving the table name
+ values: dict of new column names and values for them
+ or_ignore: bool stating whether an exception should be raised
when a conflicting row already exists. If True, False will be
returned by the function instead
- desc : string giving a description of the transaction
+ desc: description of the transaction, for logging and metrics
Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
+ Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
try:
- yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+ await self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -641,7 +677,9 @@ class Database(object):
return True
@staticmethod
- def simple_insert_txn(txn, table, values):
+ def simple_insert_txn(
+ txn: LoggingTransaction, table: str, values: Dict[str, Any]
+ ) -> None:
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -652,11 +690,29 @@ class Database(object):
txn.execute(sql, vals)
- def simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+ async def simple_insert_many(
+ self, table: str, values: List[Dict[str, Any]], desc: str
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table: string giving the table name
+            values: a list of dicts, each mapping column names to values to insert
+ desc: description of the transaction, for logging and metrics
+ """
+ await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
- def simple_insert_many_txn(txn, table, values):
+ def simple_insert_many_txn(
+ txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+ ) -> None:
+ """Executes an INSERT query on the named table.
+
+ Args:
+ txn: The transaction to use.
+ table: string giving the table name
+            values: a list of dicts, each mapping column names to values to insert
+ """
if not values:
return
@@ -684,16 +740,15 @@ class Database(object):
txn.executemany(sql, vals)
- @defer.inlineCallbacks
- def simple_upsert(
+ async def simple_upsert(
self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="simple_upsert",
- lock=True,
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ desc: str = "simple_upsert",
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
`lock` should generally be set to True (the default), but can be set
@@ -707,21 +762,20 @@ class Database(object):
this table.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key columns and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ desc: description of the transaction, for logging and metrics
+ lock: True to lock the table when doing the upsert.
Returns:
- Deferred(None or bool): Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
attempts = 0
while True:
try:
- result = yield self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_upsert_txn,
table,
@@ -730,7 +784,6 @@ class Database(object):
insertion_values,
lock=lock,
)
- return result
except self.engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -744,29 +797,34 @@ class Database(object):
)
def simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> Optional[bool]:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
Args:
txn: The transaction to use.
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
+ Native upserts always return None. Emulated upserts return True if a
+ new entry was created, False if an existing one was updated.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
- return self.simple_upsert_txn_native_upsert(
+ self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
+ return None
else:
return self.simple_upsert_txn_emulated(
txn,
@@ -778,18 +836,23 @@ class Database(object):
)
def simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ lock: bool = True,
+ ) -> bool:
"""
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
+ lock: True to lock the table when doing the upsert.
Returns:
- bool: Return True if a new entry was created, False if an existing
+ Returns True if a new entry was created, False if an existing
one was updated.
"""
# We need to lock the table :(, unless we're *really* careful
@@ -847,19 +910,21 @@ class Database(object):
return True
def simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ values: Dict[str, Any],
+ insertion_values: Dict[str, Any] = {},
+ ) -> None:
"""
Use the native UPSERT functionality in recent PostgreSQL versions.
Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
+ table: The table to upsert into
+ keyvalues: The unique key tables and their new values
+ values: The nonunique columns and their new values
+ insertion_values: additional key/values to use only when inserting
"""
allvalues = {} # type: Dict[str, Any]
allvalues.update(keyvalues)
@@ -989,41 +1054,93 @@ class Database(object):
return txn.execute_batch(sql, args)
- def simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
- ):
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one",
+ ) -> Dict[str, Any]:
+ ...
+
+ @overload
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
+ ...
+
+ async def simple_select_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ desc: str = "simple_select_one",
+ ) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcols: list of strings giving the names of the columns to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def simple_select_one_onecol(
+ @overload
+ async def simple_select_one_onecol(
self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="simple_select_one_onecol",
- ):
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[False] = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> Any:
+ ...
+
+ @overload
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[True] = True,
+ desc: str = "simple_select_one_onecol",
+ ) -> Optional[Any]:
+ ...
+
+ async def simple_select_one_onecol(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: bool = False,
+ desc: str = "simple_select_one_onecol",
+ ) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ retcol: string giving the name of the column to return
+ allow_none: If true, return None instead of failing if the SELECT
+ statement returns no rows
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,
@@ -1032,10 +1149,39 @@ class Database(object):
allow_none=allow_none,
)
+ @overload
@classmethod
def simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[False] = False,
+ ) -> Any:
+ ...
+
+ @overload
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: Literal[True] = True,
+ ) -> Optional[Any]:
+ ...
+
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcol: str,
+ allow_none: bool = False,
+ ) -> Optional[Any]:
ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
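A hedged sketch (table and column names are placeholders) of what the `Literal` overloads above buy callers: with the default `allow_none=False` the result is typed as a plain dict and a missing row raises `StoreError(404)`, while `allow_none=True` makes the result `Optional` so mypy forces a None check.

    async def example(db_pool):
        # Typed as Dict[str, Any]; raises StoreError(404) if no row matches.
        row = await db_pool.simple_select_one(
            "users", {"name": "@alice:example.com"}, ("admin",)
        )

        # Typed as Optional[Dict[str, Any]].
        maybe_row = await db_pool.simple_select_one(
            "users", {"name": "@bob:example.com"}, ("admin",), allow_none=True
        )
        return row["admin"], maybe_row["admin"] if maybe_row else None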
@@ -1049,64 +1195,85 @@ class Database(object):
raise StoreError(404, "No row found")
@staticmethod
- def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ def simple_select_onecol_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+ ) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
txn.execute(sql, list(keyvalues.values()))
else:
txn.execute(sql)
return [r[0] for r in txn]
- def simple_select_onecol(
- self, table, keyvalues, retcol, desc="simple_select_onecol"
- ):
+ async def simple_select_onecol(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcol: str,
+ desc: str = "simple_select_onecol",
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
- retcol (str): column whos value we wish to retrieve.
+ table: table name
+ keyvalues: column names and values to select the rows with
+            retcol: column whose value we wish to retrieve.
+ desc: description of the transaction, for logging and metrics
Returns:
- Deferred: Results in a list
+            A list of the values of `retcol` from the selected rows.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
- def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ async def simple_select_list(
+ self,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ desc: str = "simple_select_list",
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
+ desc: description of the transaction, for logging and metrics
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
- def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ def simple_select_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Optional[Dict[str, Any]],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
+ retcols: the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1121,28 +1288,29 @@ class Database(object):
return cls.cursor_to_dict(txn)
- @defer.inlineCallbacks
- def simple_select_many_batch(
+ async def simple_select_many_batch(
self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="simple_select_many_batch",
- batch_size=100,
- ):
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ retcols: Iterable[str],
+ keyvalues: Dict[str, Any] = {},
+ desc: str = "simple_select_many_batch",
+ batch_size: int = 100,
+ ) -> List[Any]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
- Filters rows by if value of `column` is in `iterable`.
+ Filters rows by whether the value of `column` is in `iterable`.
Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+            iterable: list of values to match against `column`
+ retcols: list of strings giving the names of the columns to return
+ keyvalues: dict of column names and values to select the rows with
+ desc: description of the transaction, for logging and metrics
+ batch_size: the number of rows for each select query
"""
results = [] # type: List[Dict[str, Any]]
@@ -1156,7 +1324,7 @@ class Database(object):
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
- rows = yield self.runInteraction(
+ rows = await self.runInteraction(
desc,
self.simple_select_many_txn,
table,
@@ -1171,19 +1339,27 @@ class Database(object):
return results
@classmethod
- def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ def simple_select_many_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
- Filters rows by if value of `column` is in `iterable`.
+ Filters rows by whether the value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+            iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
+ retcols: list of strings giving the names of the columns to return
"""
if not iterable:
return []
@@ -1191,7 +1367,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1204,15 +1380,26 @@ class Database(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
- def simple_update(self, table, keyvalues, updatevalues, desc):
- return self.runInteraction(
+ async def simple_update(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str,
+ ) -> int:
+ return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
- def simple_update_txn(txn, table, keyvalues, updatevalues):
+ def simple_update_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> int:
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else:
where = ""
@@ -1226,32 +1413,34 @@ class Database(object):
return txn.rowcount
- def simple_update_one(
- self, table, keyvalues, updatevalues, desc="simple_update_one"
- ):
+ async def simple_update_one(
+ self,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ desc: str = "simple_update_one",
+ ) -> None:
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
- retcols : optional list of column names to return
-
- If present, retcols gives a list of column names on which to perform
- a SELECT statement *before* performing the UPDATE statement. The values
- of these will be returned in a dict.
-
- These are performed within the same transaction, allowing an atomic
- get-and-set. This can be used to implement compare-and-set by putting
- the update column in the 'keyvalues' dict as well.
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ updatevalues: dict giving column names and values to update
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(
+ await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
- def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ def simple_update_one_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ updatevalues: Dict[str, Any],
+ ) -> None:
rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
@@ -1259,8 +1448,18 @@ class Database(object):
if rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
+ # Ideally we could use the overload decorator here to specify that the
+ # return type is only optional if allow_none is True, but this does not work
+ # when you call a static method from an instance.
+ # See https://github.com/python/mypy/issues/7781
@staticmethod
- def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ def simple_select_one_txn(
+ txn: LoggingTransaction,
+ table: str,
+ keyvalues: Dict[str, Any],
+ retcols: Iterable[str],
+ allow_none: bool = False,
+ ) -> Optional[Dict[str, Any]]:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@@ -1279,24 +1478,29 @@ class Database(object):
return dict(zip(retcols, row))
- def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ async def simple_delete_one(
+ self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ desc: description of the transaction, for logging and metrics
"""
- return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+ await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
- def simple_delete_one_txn(txn, table, keyvalues):
+ def simple_delete_one_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> None:
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
@@ -1309,11 +1513,38 @@ class Database(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- def simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+ async def simple_delete(
+ self, table: str, keyvalues: Dict[str, Any], desc: str
+ ) -> int:
+ """Executes a DELETE query on the named table.
+
+ Filters rows by the key-value pairs.
+
+ Args:
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+ desc: description of the transaction, for logging and metrics
+
+ Returns:
+ The number of deleted rows.
+ """
+ return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
- def simple_delete_txn(txn, table, keyvalues):
+ def simple_delete_txn(
+ txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+ ) -> int:
+ """Executes a DELETE query on the named table.
+
+ Filters rows by the key-value pairs.
+
+ Args:
+ table: string giving the table name
+ keyvalues: dict of column names and values to select the row with
+
+ Returns:
+ The number of deleted rows.
+ """
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1322,26 +1553,53 @@ class Database(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
- def simple_delete_many(self, table, column, iterable, keyvalues, desc):
- return self.runInteraction(
+ async def simple_delete_many(
+ self,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ desc: str,
+ ) -> int:
+ """Executes a DELETE query on the named table.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+            iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
+ desc: description of the transaction, for logging and metrics
+
+ Returns:
+            The number of rows deleted.
+ """
+ return await self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
- def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ def simple_delete_many_txn(
+ txn: LoggingTransaction,
+ table: str,
+ column: str,
+ iterable: Iterable[Any],
+ keyvalues: Dict[str, Any],
+ ) -> int:
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
+ txn: Transaction object
+ table: string giving the table name
+ column: column name to test for inclusion against `iterable`
+            iterable: list of values to match against `column`
+ keyvalues: dict of column names and values to select the rows with
Returns:
- int: Number rows deleted
+            The number of rows deleted.
"""
if not iterable:
return 0
@@ -1351,7 +1609,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
- for key, value in iteritems(keyvalues):
+ for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -1362,8 +1620,14 @@ class Database(object):
return txn.rowcount
def get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
+ self,
+ db_conn: Connection,
+ table: str,
+ entity_column: str,
+ stream_column: str,
+ max_value: int,
+ limit: int = 100000,
+ ) -> Tuple[Dict[Any, int], int]:
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@@ -1388,71 +1652,25 @@ class Database(object):
txn.close()
if cache:
- min_val = min(itervalues(cache))
+ min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
- def simple_select_list_paginate(
- self,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- desc="simple_select_list_paginate",
- ):
- """
- Executes a SELECT query on the named table with start and limit,
- of row numbers, which may return zero or number of rows from start to limit,
- returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- filters (dict[str, T] | None):
- column names and values to filter the rows with, or None to not
- apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc,
- self.simple_select_list_paginate_txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=filters,
- keyvalues=keyvalues,
- order_direction=order_direction,
- )
-
@classmethod
def simple_select_list_paginate_txn(
cls,
- txn,
- table,
- orderby,
- start,
- limit,
- retcols,
- filters=None,
- keyvalues=None,
- order_direction="ASC",
- ):
+ txn: LoggingTransaction,
+ table: str,
+ orderby: str,
+ start: int,
+ limit: int,
+ retcols: Iterable[str],
+ filters: Optional[Dict[str, Any]] = None,
+ keyvalues: Optional[Dict[str, Any]] = None,
+ order_direction: str = "ASC",
+ ) -> List[Dict[str, Any]]:
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
@@ -1463,21 +1681,22 @@ class Database(object):
using 'AND'.
Args:
- txn : Transaction object
- table (str): the table name
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- filters (dict[str, T] | None):
+ txn: Transaction object
+ table: the table name
+ orderby: Column to order the results by.
+ start: Index to begin the query at.
+ limit: Number of results to return.
+ retcols: the names of the columns to return
+ filters:
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
- keyvalues (dict[str, T] | None):
+ keyvalues:
column names and values to select the rows with, or None to not
apply a WHERE clause.
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ order_direction: Whether the results should be ordered "ASC" or "DESC".
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ The result as a list of dictionaries.
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1503,51 +1722,65 @@ class Database(object):
return cls.cursor_to_dict(txn)
- def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ async def simple_search_list(
+ self,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ desc="simple_search_list",
+ ) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ A list of dictionaries or None.
"""
- return self.runInteraction(
+ return await self.runInteraction(
desc, self.simple_search_list_txn, table, term, col, retcols
)
@classmethod
- def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ def simple_search_list_txn(
+ cls,
+ txn: LoggingTransaction,
+ table: str,
+ term: Optional[str],
+ col: str,
+ retcols: Iterable[str],
+ ) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
- txn : Transaction object
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
+ txn: Transaction object
+ table: the table name
+ term: term for searching the table matched to a column.
+ col: column to query term should be matched to
+ retcols: the names of the columns to return
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
+ None if no term is given, otherwise a list of dictionaries.
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
- return 0
+ return None
return cls.cursor_to_dict(txn)
def make_in_list_sql_clause(
- database_engine, column: str, iterable: Iterable
+ database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
new file mode 100644
index 0000000000..985b12df91
--- /dev/null
+++ b/synapse/storage/databases/__init__.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.events import PersistEventsStore
+from synapse.storage.databases.state import StateGroupDataStore
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+
+logger = logging.getLogger(__name__)
+
+
+class Databases:
+ """The various databases.
+
+ These are low level interfaces to physical databases.
+
+ Attributes:
+ main (DataStore)
+ """
+
+ def __init__(self, main_store_class, hs):
+ # Note we pass in the main store class here as workers use a different main
+ # store.
+
+ self.databases = []
+ main = None
+ state = None
+ persist_events = None
+
+ for database_config in hs.config.database.databases:
+ db_name = database_config.name
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine) as db_conn:
+ logger.info("[database config %r]: Checking database server", db_name)
+ engine.check_database(db_conn)
+
+ logger.info(
+ "[database config %r]: Preparing for databases %r",
+ db_name,
+ database_config.databases,
+ )
+ prepare_database(
+ db_conn, engine, hs.config, databases=database_config.databases,
+ )
+
+ database = DatabasePool(hs, database_config, engine)
+
+ if "main" in database_config.databases:
+ logger.info(
+ "[database config %r]: Starting 'main' database", db_name
+ )
+
+ # Sanity check we don't try and configure the main store on
+ # multiple databases.
+ if main:
+ raise Exception("'main' data store already configured")
+
+ main = main_store_class(database, db_conn, hs)
+
+ # If we're on a process that can persist events also
+ # instantiate a `PersistEventsStore`
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ persist_events = PersistEventsStore(hs, database, main)
+
+ if "state" in database_config.databases:
+ logger.info(
+ "[database config %r]: Starting 'state' database", db_name
+ )
+
+ # Sanity check we don't try and configure the state store on
+ # multiple databases.
+ if state:
+ raise Exception("'state' data store already configured")
+
+ state = StateGroupDataStore(database, db_conn, hs)
+
+ db_conn.commit()
+
+ self.databases.append(database)
+
+ logger.info("[database config %r]: prepared", db_name)
+
+ # Closing the context manager doesn't close the connection.
+ # psycopg will close the connection when the object gets GCed, but *only*
+ # if the PID is the same as when the connection was opened [1], and
+ # it may not be if we fork in the meantime.
+ #
+ # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
+
+ db_conn.close()
+
+ # Sanity check that we have actually configured all the required stores.
+ if not main:
+ raise Exception("No 'main' database configured")
+
+ if not state:
+ raise Exception("No 'state' database configured")
+
+ # We use local variables here to ensure that the databases do not have
+ # optional types.
+ self.main = main
+ self.state = state
+ self.persist_events = persist_events
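The constructor above is easiest to read as enforcing one invariant: across all configured physical databases, the logical "main" and "state" stores must each be configured exactly once. A minimal sketch of that invariant check in isolation (not part of the patch), with hypothetical plain-dict configs standing in for the real database config objects:

# Illustrative sketch only; check_database_layout is a hypothetical helper.
from typing import Dict, List

def check_database_layout(configs: List[Dict[str, list]]) -> None:
    seen = set()
    for config in configs:
        for logical in config["databases"]:
            if logical in seen:
                raise Exception("'%s' data store already configured" % (logical,))
            seen.add(logical)
    for required in ("main", "state"):
        if required not in seen:
            raise Exception("No '%s' database configured" % (required,))

# check_database_layout([{"databases": ["main"]}, {"databases": ["state"]}])  # passes
# check_database_layout([{"databases": ["main", "state"]}] * 2)               # raises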
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 4b4763c701..2ae2fbd5d7 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,16 +18,18 @@
import calendar
import logging
import time
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
+from synapse.types import get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
@@ -119,7 +121,7 @@ class DataStore(
CacheInvalidationWorkerStore,
ServerMetricsStore,
):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
@@ -128,7 +130,7 @@ class DataStore(
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
@@ -174,7 +176,7 @@ class DataStore(
self._presence_on_startup = self._get_active_presence(db_conn)
- presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
+ presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
db_conn,
"presence_stream",
entity_column="user_id",
@@ -188,7 +190,7 @@ class DataStore(
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
- device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
+ device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
db_conn,
"device_inbox",
entity_column="user_id",
@@ -203,7 +205,7 @@ class DataStore(
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
- device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
+ device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
db_conn,
"device_federation_outbox",
entity_column="destination",
@@ -229,7 +231,7 @@ class DataStore(
)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
@@ -243,7 +245,7 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
- _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
+ _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
@@ -263,6 +265,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
@@ -282,7 +287,7 @@ class DataStore(
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
txn.close()
for row in rows:
@@ -290,14 +295,16 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
- def count_daily_users(self):
+ async def count_daily_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
+ return await self.db_pool.runInteraction(
+ "count_daily_users", self._count_users, yesterday
+ )
- def count_monthly_users(self):
+ async def count_monthly_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
@@ -305,7 +312,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@@ -324,15 +331,15 @@ class DataStore(
(count,) = txn.fetchone()
return count
- def count_r30_users(self):
+ async def count_r30_users(self) -> Dict[str, int]:
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
- Returns counts globaly for a given user as well as breaking
- by platform
+ Returns:
+ A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn):
@@ -405,7 +412,7 @@ class DataStore(
return results
- return self.db.runInteraction("count_r30_users", _count_r30_users)
+ return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@@ -415,7 +422,7 @@ class DataStore(
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
- def generate_user_daily_visits(self):
+ async def generate_user_daily_visits(self) -> None:
"""
Generates daily visit data for use in cohort/ retention analysis
"""
@@ -470,18 +477,17 @@ class DataStore(
# frequently
self._last_user_visit_update = now
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
- def get_users(self):
+ async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table.
- Args:
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries representing users.
"""
- return self.db.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="users",
keyvalues={},
retcols=[
@@ -495,30 +501,40 @@ class DataStore(
desc="get_users",
)
- def get_users_paginate(
- self, start, limit, name=None, guests=True, deactivated=False
- ):
+ async def get_users_paginate(
+ self,
+ start: int,
+ limit: int,
+ user_id: Optional[str] = None,
+ name: Optional[str] = None,
+ guests: bool = True,
+ deactivated: bool = False,
+ ) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Args:
- start (int): start number to begin the query from
- limit (int): number of rows to retrieve
- name (string): filter for user names
- guests (bool): whether to in include guest users
- deactivated (bool): whether to include deactivated users
+ start: start number to begin the query from
+ limit: number of rows to retrieve
+ user_id: search for user_id. Ignored if name is not None
+ name: search for local part of user_id or display name
+ guests: whether to include guest users
+ deactivated: whether to include deactivated users
Returns:
- defer.Deferred: resolves to list[dict[str, Any]], int
+ A tuple of a list of dicts, each describing a user, and a count of the total number of users matching the filter criteria.
"""
def get_users_paginate_txn(txn):
filters = []
- args = []
+ args = [self.hs.config.server_name]
if name:
+ filters.append("(name LIKE ? OR displayname LIKE ?)")
+ args.extend(["@%" + name + "%:%", "%" + name + "%"])
+ elif user_id:
filters.append("name LIKE ?")
- args.append("%" + name + "%")
+ args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -528,37 +544,42 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
- sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
- txn.execute(sql, args)
- count = txn.fetchone()[0]
-
- args = [self.hs.config.server_name] + args + [limit, start]
- sql = """
- SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
- ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
+ sql = "SELECT COUNT(*) as total_users " + sql_base
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = (
+ "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ + sql_base
+ + " ORDER BY u.name LIMIT ? OFFSET ?"
+ )
+ args += [limit, start]
txn.execute(sql, args)
- users = self.db.cursor_to_dict(txn)
+ users = self.db_pool.cursor_to_dict(txn)
return users, count
- return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
+ return await self.db_pool.runInteraction(
+ "get_users_paginate_txn", get_users_paginate_txn
+ )
- def search_users(self, term):
+ async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
"""Function to search users list for one or more users with
the matched term.
Args:
- term (str): search term
- col (str): column to query term should be matched to
+ term: search term
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ A list of dictionaries or None.
"""
- return self.db.simple_search_list(
+ return await self.db_pool.simple_search_list(
table="users",
term=term,
col="name",
@@ -571,21 +592,24 @@ def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig
"""Called before upgrading an existing database to check that it is broadly sane
compared with the configuration.
"""
- domain = config.server_name
+ logger.info("Checking database for consistency with configuration...")
- sql = database_engine.convert_param_style(
- "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
- )
- pat = "%:" + domain
- cur.execute(sql, (pat,))
- num_not_matching = cur.fetchall()[0][0]
- if num_not_matching == 0:
+ # if there are any users in the database, check that the username matches our
+ # configured server name.
+
+ cur.execute("SELECT name FROM users LIMIT 1")
+ rows = cur.fetchall()
+ if not rows:
+ return
+
+ user_domain = get_domain_from_id(rows[0][0])
+ if user_domain == config.server_name:
return
raise Exception(
"Found users in database not native to %s!\n"
- "You cannot changed a synapse server_name after it's been configured"
- % (domain,)
+ "You cannot change a synapse server_name after it's been configured"
+ % (config.server_name,)
)
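The rewritten `check_database_before_upgrade` no longer counts non-matching users; it samples a single row from `users` and compares that user's domain with the configured `server_name`. A minimal standalone sketch of the same check (not part of the patch), with a hypothetical `domain_from_user_id` in place of `get_domain_from_id`:

# Illustrative sketch only.
def domain_from_user_id(user_id: str) -> str:
    # "@alice:example.com" -> "example.com"
    return user_id.split(":", 1)[1]

def check_server_name(cur, server_name: str) -> None:
    cur.execute("SELECT name FROM users LIMIT 1")
    rows = cur.fetchall()
    if not rows:
        return  # empty database: nothing to check
    if domain_from_user_id(rows[0][0]) != server_name:
        raise Exception(
            "Found users in database not native to %s! "
            "You cannot change a synapse server_name after it's been configured"
            % (server_name,)
        )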
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/databases/main/account_data.py
index b58f04d00d..4436b1a83d 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,16 +16,14 @@
import abc
import logging
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -40,7 +38,7 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
@@ -58,18 +56,20 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
- def get_account_data_for_user(self, user_id):
+ async def get_account_data_for_user(
+ self, user_id: str
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
- user_id(str): The user to get the account_data for.
+ user_id: The user to get the account_data for.
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A 2-tuple of a dict of global account_data and a dict mapping from
+ room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@@ -77,10 +77,10 @@ class AccountDataWorkerStore(SQLBaseStore):
)
global_account_data = {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@@ -90,21 +90,23 @@ class AccountDataWorkerStore(SQLBaseStore):
by_room = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
- room_data[row["account_data_type"]] = json.loads(row["content"])
+ room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
- @cachedInlineCallbacks(num_args=2, max_entries=5000)
- def get_global_account_data_by_type_for_user(self, data_type, user_id):
+ @cached(num_args=2, max_entries=5000)
+ async def get_global_account_data_by_type_for_user(
+ self, data_type: str, user_id: str
+ ) -> Optional[JsonDict]:
"""
Returns:
- Deferred: A dict
+ The account data of the given type for the user, or None if there isn't any.
"""
- result = yield self.db.simple_select_one_onecol(
+ result = await self.db_pool.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@@ -113,23 +115,25 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- return json.loads(result)
+ return db_to_json(result)
else:
return None
@cached(num_args=2)
- def get_account_data_for_room(self, user_id, room_id):
+ async def get_account_data_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
Returns:
- A deferred dict of the room account_data
+ A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@@ -137,28 +141,29 @@ class AccountDataWorkerStore(SQLBaseStore):
)
return {
- row["account_data_type"]: json.loads(row["content"]) for row in rows
+ row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
- def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ async def get_account_data_for_room_and_type(
+ self, user_id: str, room_id: str, account_data_type: str
+ ) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
- user_id(str): The user to get the account_data for.
- room_id(str): The room to get the account_data for.
- account_data_type (str): The account data type to get.
+ user_id: The user to get the account_data for.
+ room_id: The room to get the account_data for.
+ account_data_type: The account data type to get.
Returns:
- A deferred of the room account_data for that type, or None if
- there isn't any set.
+ The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
- content_json = self.db.simple_select_one_onecol_txn(
+ content_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@@ -170,9 +175,9 @@ class AccountDataWorkerStore(SQLBaseStore):
allow_none=True,
)
- return json.loads(content_json) if content_json else None
+ return db_to_json(content_json) if content_json else None
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -202,7 +207,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
)
@@ -232,16 +237,18 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- def get_updated_account_data_for_user(self, user_id, stream_id):
+ async def get_updated_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
- user_id(str): The user to get the account_data for.
- stream_id(int): The point in the stream since which to get updates
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
@@ -255,7 +262,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: json.loads(row[1]) for row in txn}
+ global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
@@ -267,7 +274,7 @@ class AccountDataWorkerStore(SQLBaseStore):
account_data_by_room = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
- room_account_data[row[1]] = json.loads(row[2])
+ room_account_data[row[1]] = db_to_json(row[2])
return global_account_data, account_data_by_room
@@ -275,15 +282,17 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
- return defer.succeed(({}, {}))
+ return ({}, {})
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
- @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
- def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
- ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+ @cached(num_args=2, cache_context=True, max_entries=5000)
+ async def is_ignored_by(
+ self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
+ ) -> bool:
+ ignored_account_data = await self.get_global_account_data_by_type_for_user(
"m.ignored_user_list",
ignorer_user_id,
on_invalidate=cache_context.invalidate,
@@ -295,7 +304,7 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"account_data_max_stream_id",
@@ -308,32 +317,35 @@ class AccountDataStore(AccountDataWorkerStore):
super(AccountDataStore, self).__init__(database, db_conn, hs)
- def get_max_account_data_stream_id(self):
+ def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream
Returns:
- A deferred int.
+ The maximum stream ID.
"""
return self._account_data_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
+ async def add_account_data_to_room(
+ self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
"""Add some account_data to a room for a user.
+
Args:
- user_id(str): The user to add a tag for.
- room_id(str): The room to add a tag for.
- account_data_type(str): The type of account_data to add.
- content(dict): A json object to associate with the tag.
+ user_id: The user to add the account data for.
+ room_id: The room to add the account data to.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the account data.
+
Returns:
- A deferred that completes once the account_data has been added.
+ The maximum stream ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
- yield self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@@ -351,7 +363,7 @@ class AccountDataStore(AccountDataWorkerStore):
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
- yield self._update_max_stream_id(next_id)
+ await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
@@ -360,26 +372,28 @@ class AccountDataStore(AccountDataWorkerStore):
(user_id, room_id, account_data_type), content
)
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_account_data_for_user(self, user_id, account_data_type, content):
+ async def add_account_data_for_user(
+ self, user_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
"""Add some account_data to a room for a user.
+
Args:
- user_id(str): The user to add a tag for.
- account_data_type(str): The type of account_data to add.
- content(dict): A json object to associate with the tag.
+ user_id: The user to add the account data for.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the account data.
+
Returns:
- A deferred that completes once the account_data has been added.
+ The maximum stream ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
- with self._account_data_id_gen.get_next() as next_id:
+ with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
- yield self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -397,7 +411,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
- yield self._update_max_stream_id(next_id)
+ await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
@@ -405,14 +419,13 @@ class AccountDataStore(AccountDataWorkerStore):
(account_data_type, user_id)
)
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- def _update_max_stream_id(self, next_id):
+ async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
- next_id(int): The the revision to advance to.
+ next_id: The revision to advance to.
"""
# Note: This is only here for backwards compat to allow admins to
@@ -427,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.db.runInteraction("update_account_data_max_stream_id", _update)
+ await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
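The account-data reads above all follow the same shape: pull JSON text rows out of `account_data` / `room_account_data` and decode them into a flat dict of global account data plus a per-room dict. A minimal sketch of that shaping step (not part of the patch), using the stdlib `json` module as a stand-in for `db_to_json` and a hypothetical `shape_account_data` helper:

# Illustrative sketch only; shape_account_data is a hypothetical helper.
import json
from typing import Any, Dict, Iterable, Tuple

def shape_account_data(
    global_rows: Iterable[Dict[str, Any]], room_rows: Iterable[Dict[str, Any]]
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
    global_account_data = {
        row["account_data_type"]: json.loads(row["content"]) for row in global_rows
    }
    by_room = {}  # type: Dict[str, Dict[str, Any]]
    for row in room_rows:
        room_data = by_room.setdefault(row["room_id"], {})
        room_data[row["account_data_type"]] = json.loads(row["content"])
    return global_account_data, by_room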
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 7a1fe8cdd2..454c0bc50c 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,15 +16,12 @@
import logging
import re
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -49,7 +46,7 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@@ -124,17 +121,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
- @defer.inlineCallbacks
- def get_appservices_by_state(self, state):
+ async def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
Args:
state(ApplicationServiceState): The state to filter on.
Returns:
- A Deferred which resolves to a list of ApplicationServices, which
- may be empty.
+ A list of ApplicationServices, which may be empty.
"""
- results = yield self.db.simple_select_list(
+ results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
@@ -147,16 +142,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
- @defer.inlineCallbacks
- def get_appservice_state(self, service):
+ async def get_appservice_state(self, service):
"""Get the application service state.
Args:
service(ApplicationService): The service whose state to set.
Returns:
- A Deferred which resolves to ApplicationServiceState.
+ An ApplicationServiceState, or None if we have no record of the service.
"""
- result = yield self.db.simple_select_one(
+ result = await self.db_pool.simple_select_one(
"application_services_state",
{"as_id": service.id},
["state"],
@@ -167,20 +161,18 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
- def set_appservice_state(self, service, state):
+ async def set_appservice_state(self, service, state) -> None:
"""Set the application service state.
Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
- Returns:
- A Deferred which resolves when the state was set successfully.
"""
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)
- def create_appservice_txn(self, service, events):
+ async def create_appservice_txn(self, service, events):
"""Atomically creates a new transaction for this application service
with the given list of events.
@@ -209,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
- event_ids = json.dumps([e.event_id for e in events])
+ event_ids = json_encoder.encode([e.event_id for e in events])
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
@@ -217,18 +209,17 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
- return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
+ return await self.db_pool.runInteraction(
+ "create_appservice_txn", _create_appservice_txn
+ )
- def complete_appservice_txn(self, txn_id, service):
+ async def complete_appservice_txn(self, txn_id, service) -> None:
"""Completes an application service transaction.
Args:
txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent
this transaction.
- Returns:
- A Deferred which resolves if this transaction was stored
- successfully.
"""
txn_id = int(txn_id)
@@ -250,7 +241,7 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
"application_services_state",
{"as_id": service.id},
@@ -258,26 +249,24 @@ class ApplicationServiceTransactionWorkerStore(
)
# Delete txn
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
"application_services_txns",
{"txn_id": txn_id, "as_id": service.id},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn
)
- @defer.inlineCallbacks
- def get_oldest_unsent_txn(self, service):
+ async def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.
Args:
service(ApplicationService): The app service to get the oldest txn.
Returns:
- A Deferred which resolves to an AppServiceTransaction or
- None.
+ An AppServiceTransaction or None.
"""
def _get_oldest_unsent_txn(txn):
@@ -288,7 +277,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return None
@@ -296,16 +285,16 @@ class ApplicationServiceTransactionWorkerStore(
return entry
- entry = yield self.db.runInteraction(
+ entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
if not entry:
return None
- event_ids = json.loads(entry["event_ids"])
+ event_ids = db_to_json(entry["event_ids"])
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@@ -320,18 +309,17 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
- def set_appservice_last_pos(self, pos):
+ async def set_appservice_last_pos(self, pos) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
- @defer.inlineCallbacks
- def get_new_events_for_appservice(self, current_id, limit):
+ async def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
def get_new_events_for_appservice_txn(txn):
@@ -355,11 +343,11 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db.runInteraction(
+ upper_bound, event_ids = await self.db_pool.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return upper_bound, events
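Application service transactions persist only their event IDs, JSON-encoded into a single text column, and the oldest unsent transaction is rehydrated by decoding that column and refetching the events. A minimal sketch of the ID round trip and the transaction numbering used in `create_appservice_txn` (not part of the patch):

# Illustrative sketch only.
import json
from typing import Iterable, List

def encode_txn_event_ids(event_ids: Iterable[str]) -> str:
    # Stored in application_services_txns.event_ids
    return json.dumps(list(event_ids))

def decode_txn_event_ids(column_value: str) -> List[str]:
    return json.loads(column_value)

def next_txn_id(highest_txn_id: int, last_txn_id: int) -> int:
    # New transactions are numbered after both the highest stored txn and the
    # last completed txn, whichever is larger.
    return max(highest_txn_id, last_txn_id) + 1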
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/databases/main/cache.py
index eac5a4e55b..1e7637a6f5 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -16,15 +16,17 @@
import itertools
import logging
-from typing import Any, Iterable, Optional, Tuple
+from typing import Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
+ EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -37,20 +39,37 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ):
- """Fetches cache invalidation rows between the two given IDs written
- by the given instance. Returns at most `limit` rows.
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for caches replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
- return []
+ return [], current_id, False
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
@@ -64,17 +83,24 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
- return txn.fetchall()
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "events":
+ if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
- elif stream_name == "backfill":
+ elif stream_name == BackfillStream.NAME:
for row in rows:
self._invalidate_caches_for_event(
-token,
@@ -86,7 +112,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
row.relates_to,
backfilled=True,
)
- elif stream_name == "caches":
+ elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
@@ -176,7 +202,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return
cache_func.invalidate(keys)
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
@@ -261,7 +287,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if keys is not None:
keys = list(keys)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="cache_invalidation_stream_by_instance",
values={
@@ -273,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
- def get_cache_stream_token(self, instance_name):
+ def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token(instance_name)
+ return self._cache_id_gen.get_current_token_for_writer(instance_name)
else:
return 0
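The expanded `get_all_updated_caches` now follows the common replication-stream contract: return the rows, the token the caller should resume from, and whether the limit truncated the result. A minimal sketch of that post-processing step on already-ordered rows (not part of the patch), using a hypothetical `paginate_stream_rows` helper:

# Illustrative sketch only; paginate_stream_rows is a hypothetical helper.
from typing import List, Tuple

def paginate_stream_rows(
    rows: List[tuple], current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    updates = [(row[0], row[1:]) for row in rows]
    limited = False
    upto_token = current_id
    if len(updates) >= limit:
        upto_token = updates[-1][0]
        limited = True
    return updates, upto_token, limited

# paginate_stream_rows([(5, "a", "b"), (6, "c", "d")], current_id=10, limit=2)
# -> ([(5, ("a", "b")), (6, ("c", "d"))], 6, True)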
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 2d48261724..f211ddbaf8 100644
--- a/synapse/storage/data_stores/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -16,15 +16,13 @@
import logging
from typing import TYPE_CHECKING
-from twisted.internet import defer
-
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.data_stores.main.events import encode_json
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.events import encode_json
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -34,7 +32,7 @@ logger = logging.getLogger(__name__)
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
def _censor_redactions():
@@ -56,7 +54,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
if not (
- await self.db.updates.has_completed_background_update(
+ await self.db_pool.updates.has_completed_background_update(
"redactions_have_censored_ts_idx"
)
):
@@ -85,7 +83,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
LIMIT ?
"""
- rows = await self.db.execute(
+ rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
)
@@ -123,14 +121,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="redactions",
keyvalues={"event_id": redaction_id},
updatevalues={"have_censored": True},
)
- await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+ await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the
@@ -141,24 +139,23 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
"""
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="event_json",
keyvalues={"event_id": event_id},
updatevalues={"json": pruned_json},
)
- @defer.inlineCallbacks
- def expire_event(self, event_id):
+ async def expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that has expired, and delete its associated
expiry timestamp. If the event can't be retrieved, delete its associated
timestamp so we don't try to expire it again in the future.
Args:
- event_id (str): The ID of the event to delete.
+ event_id: The ID of the event to delete.
"""
# Try to retrieve the event's content from the database or the event cache.
- event = yield self.get_event(event_id)
+ event = await self.get_event(event_id)
def delete_expired_event_txn(txn):
# Delete the expiry timestamp associated with this event from the database.
@@ -193,7 +190,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
txn, "_get_event_cache", (event.event_id,)
)
- yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+ await self.db_pool.runInteraction(
+ "delete_expired_event", delete_expired_event_txn
+ )
def _delete_event_expiry_txn(self, txn, event_id):
"""Delete the expiry timestamp associated with an event ID without deleting the
@@ -203,6 +202,6 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
"""
- return self.db.simple_delete_txn(
+ return self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
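Censoring boils down to two row updates inside one transaction: overwrite the stored event JSON with the pruned form, then mark the redaction as handled so it is not picked up again. A minimal sketch of those updates against a DB-API cursor (not part of the patch):

# Illustrative sketch only.
def censor_event(txn, event_id: str, pruned_json: str, redaction_id: str) -> None:
    # Equivalent of _censor_event_txn: replace the event's JSON in event_json.
    txn.execute(
        "UPDATE event_json SET json = ? WHERE event_id = ?",
        (pruned_json, event_id),
    )
    # Equivalent of the simple_update_one_txn call: flag the redaction as censored.
    txn.execute(
        "UPDATE redactions SET have_censored = ? WHERE event_id = ?",
        (True, redaction_id),
    )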
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 71f8d43a76..c2fc847fbc 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,14 +14,11 @@
# limitations under the License.
import logging
-
-from six import iteritems
-
-from twisted.internet import defer
+from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database, make_tuple_comparison_clause
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@@ -33,40 +30,40 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_last_seen_index",
index_name="user_ips_last_seen",
table="user_ips",
columns=["user_id", "last_seen"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_last_seen_only_index",
index_name="user_ips_last_seen_only",
table="user_ips",
columns=["last_seen"],
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_ips_analyze", self._analyze_user_ip
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index",
table="user_ips",
@@ -75,28 +72,28 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
)
# Drop the old non-unique index
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
# Update the last seen info in devices.
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"devices_last_seen", self._devices_last_seen_update
)
- @defer.inlineCallbacks
- def _remove_user_ip_nonunique(self, progress, batch_size):
+ async def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
- yield self.db.runWithConnection(f)
- yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
+ await self.db_pool.runWithConnection(f)
+ await self.db_pool.updates._end_background_update(
+ "user_ips_drop_nonunique_index"
+ )
return 1
- @defer.inlineCallbacks
- def _analyze_user_ip(self, progress, batch_size):
+ async def _analyze_user_ip(self, progress, batch_size):
# Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed
# for ages due to the table locks.
@@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
- yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
+ await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
- yield self.db.updates._end_background_update("user_ips_analyze")
+ await self.db_pool.updates._end_background_update("user_ips_analyze")
return 1
- @defer.inlineCallbacks
- def _remove_user_ip_dupes(self, progress, batch_size):
+ async def _remove_user_ip_dupes(self, progress, batch_size):
# This works function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they
@@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
- end_last_seen = yield self.db.runInteraction(
+ end_last_seen = await self.db_pool.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@@ -271,19 +267,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
(user_id, access_token, ip, device_id, user_agent, last_seen),
)
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
- yield self.db.runInteraction("user_ips_dups_remove", remove)
+ await self.db_pool.runInteraction("user_ips_dups_remove", remove)
if last:
- yield self.db.updates._end_background_update("user_ips_remove_dupes")
+ await self.db_pool.updates._end_background_update("user_ips_remove_dupes")
return batch_size
- @defer.inlineCallbacks
- def _devices_last_seen_update(self, progress, batch_size):
+ async def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
@@ -338,7 +333,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
@@ -346,18 +341,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return len(rows)
- updated = yield self.db.runInteraction(
+ updated = await self.db_pool.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
- yield self.db.updates._end_background_update("devices_last_seen")
+ await self.db_pool.updates._end_background_update("devices_last_seen")
return updated
class ClientIpStore(ClientIpBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
@@ -380,8 +375,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
- @defer.inlineCallbacks
- def insert_client_ip(
+ async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
if not now:
@@ -392,7 +386,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- yield self.populate_monthly_active_users(user_id)
+ await self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
@@ -402,30 +396,30 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
- def _update_client_ips_batch(self):
+ async def _update_client_ips_batch(self) -> None:
# If the DB pool has already terminated, don't try updating
- if not self.db.is_running():
+ if not self.db_pool.is_running():
return
to_update = self._batch_row_update
self._batch_row_update = {}
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
- if "user_ips" in self.db._unsafe_to_upsert_tables or (
+ if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
- for entry in iteritems(to_update):
+ for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@@ -447,7 +441,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -461,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
- @defer.inlineCallbacks
- def get_last_client_ip_by_device(self, user_id, device_id):
+ async def get_last_client_ip_by_device(
+ self, user_id: str, device_id: Optional[str]
+ ) -> Dict[Tuple[str, str], dict]:
"""For each device_id listed, give the user_ip it was last seen on
Args:
- user_id (str)
- device_id (str): If None fetches all devices for the user
+ user_id: The user to fetch devices for.
+ device_id: If None fetches all devices for the user
Returns:
- defer.Deferred: resolves to a dict, where the keys
- are (user_id, device_id) tuples. The values are also dicts, with
- keys giving the column names
+ A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+ keys giving the column names from the devices table.
"""
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
- res = yield self.db.simple_select_list(
+ res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -501,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
}
return ret
- @defer.inlineCallbacks
- def get_user_ip_and_agents(self, user):
+ async def get_user_ip_and_agents(self, user):
user_id = user.to_string()
results = {}
@@ -512,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
- rows = yield self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -530,7 +523,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"user_agent": user_agent,
"last_seen": last_seen,
}
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ for (access_token, ip), (user_agent, last_seen) in results.items()
]
@wrap_as_background_process("prune_old_user_ips")
@@ -542,7 +535,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Nothing to do
return
- if not await self.db.updates.has_completed_background_update(
+ if not await self.db_pool.updates.has_completed_background_update(
"devices_last_seen"
):
# Only start pruning if we have finished populating the devices
@@ -575,4 +568,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,))
- await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
+ await self.db_pool.runInteraction(
+ "_prune_old_user_ips", _prune_old_user_ips_txn
+ )
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 9a1178fb39..0044433110 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,14 +14,12 @@
# limitations under the License.
import logging
-
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import List, Tuple
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool
+from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
- def get_new_messages_for_device(
- self, user_id, device_id, last_stream_id, current_stream_id, limit=100
- ):
+ async def get_new_messages_for_device(
+ self,
+ user_id: str,
+ device_id: str,
+ last_stream_id: int,
+ current_stream_id: int,
+ limit: int = 100,
+ ) -> Tuple[List[dict], int]:
"""
Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- current_stream_id(int): The current position of the to device
+ user_id: The recipient user_id.
+ device_id: The recipient device_id.
+ last_stream_id: The last stream ID checked.
+ current_stream_id: The current position of the to device
message stream.
+ limit: The maximum number of messages to retrieve.
+
Returns:
- Deferred ([dict], int): List of messages for the device and where
- in the stream the messages got to.
+ A list of messages for the device and where in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
- return defer.succeed(([], current_stream_id))
+ return ([], current_stream_id)
def get_new_messages_for_device_txn(txn):
sql = (
@@ -64,25 +69,27 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return messages, stream_pos
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@trace
- @defer.inlineCallbacks
- def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+ async def delete_messages_for_device(
+ self, user_id: str, device_id: str, up_to_stream_id: int
+ ) -> int:
"""
Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- up_to_stream_id(int): Where to delete messages up to.
+ user_id: The recipient user_id.
+ device_id: The recipient device_id.
+ up_to_stream_id: Where to delete messages up to.
+
Returns:
- A deferred that resolves to the number of messages deleted.
+ The number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
@@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- count = yield self.db.runInteraction(
+ count = await self.db_pool.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
@@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return count
@trace
- def get_new_device_msgs_for_remote(
+ async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
- ):
+ ) -> Tuple[List[dict], int]:
"""
Args:
destination(str): The name of the remote server.
@@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
current_stream_id(int|long): The current position of the device
message stream.
Returns:
- Deferred ([dict], int|long): List of messages for the device and where
- in the stream the messages got to.
+ A list of messages for the device and where in the stream the messages got to.
"""
set_tag("destination", destination)
@@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
if not has_changed or last_stream_id == current_stream_id:
log_kv({"message": "No new messages in stream"})
- return defer.succeed(([], current_stream_id))
+ return ([], current_stream_id)
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
- return defer.succeed(([], last_stream_id))
+ return ([], last_stream_id)
@trace
def get_new_messages_for_remote_destination_txn(txn):
@@ -172,27 +178,27 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(json.loads(row[1]))
+ messages.append(db_to_json(row[1]))
if len(messages) < limit:
log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
return messages, stream_pos
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
@trace
- def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
+ async def delete_device_msgs_for_remote(
+ self, destination: str, up_to_stream_id: int
+ ) -> None:
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
- destination(str): The destination server_name
- up_to_stream_id(int): Where to delete messages up to.
- Returns:
- A deferred that resolves when the messages have been deleted.
+ destination: The destination server_name
+ up_to_stream_id: Where to delete messages up to.
"""
def delete_messages_for_remote_destination_txn(txn):
@@ -203,35 +209,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
- def get_all_new_device_messages(self, last_pos, current_pos, limit):
- """
+ async def get_all_new_device_messages(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for to device replication stream.
+
Args:
- last_pos(int):
- current_pos(int):
- limit(int):
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
Returns:
- A deferred list of rows from the device inbox
+ A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- if last_pos == current_pos:
- return defer.succeed([])
+
+ if last_id == current_id:
+ return [], current_id, False
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
- upper_pos = min(current_pos, last_pos + limit)
+ upper_pos = min(current_id, last_id + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
- txn.execute(sql, (last_pos, upper_pos))
- rows = txn.fetchall()
+ txn.execute(sql, (last_id, upper_pos))
+ updates = [(row[0], row[1:]) for row in txn]
sql = (
"SELECT max(stream_id), destination"
@@ -239,15 +260,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
- txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn)
+ txn.execute(sql, (last_id, upper_pos))
+ updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering
- rows.sort()
+ updates.sort()
+
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- return rows
+ return updates, upto_token, limited
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
@@ -255,30 +282,29 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
- @defer.inlineCallbacks
- def _background_drop_index_device_inbox(self, progress, batch_size):
+ async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
- yield self.db.runWithConnection(reindex_txn)
+ await self.db_pool.runWithConnection(reindex_txn)
- yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+ await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
@@ -286,7 +312,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceInboxStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
@@ -299,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
)
@trace
- @defer.inlineCallbacks
- def add_messages_to_device_inbox(
- self, local_messages_by_user_then_device, remote_messages_by_destination
- ):
+ async def add_messages_to_device_inbox(
+ self,
+ local_messages_by_user_then_device: dict,
+ remote_messages_by_destination: dict,
+ ) -> int:
"""Used to send messages from this server.
Args:
- sender_user_id(str): The ID of the user sending these messages.
- local_messages_by_user_and_device(dict):
+            local_messages_by_user_then_device:
Dictionary of user_id to device_id to message.
- remote_messages_by_destination(dict):
+ remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send.
+
Returns:
- A deferred stream_id that resolves when the messages have been
- inserted.
+ The new stream_id.
"""
def add_messages_txn(txn, now_ms, stream_id):
@@ -332,13 +358,13 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = json.dumps(edu)
+ edu_json = json_encoder.encode(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@@ -350,15 +376,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
return self._device_inbox_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_messages_from_remote_to_device_inbox(
- self, origin, message_id, local_messages_by_user_then_device
- ):
+ async def add_messages_from_remote_to_device_inbox(
+ self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+ ) -> int:
def add_messages_txn(txn, now_ms, stream_id):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
- already_inserted = self.db.simple_select_one_txn(
+ already_inserted = self.db_pool.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@@ -370,7 +395,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add an entry for this message_id so that we know we've processed
# it.
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@@ -386,9 +411,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with self._device_inbox_id_gen.get_next() as stream_id:
+ with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
@@ -402,9 +427,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
- txn.execute(sql, (stream_id, stream_id))
-
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
@@ -413,7 +435,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Handle wildcard device_ids.
sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
- message_json = json.dumps(messages_by_device["*"])
+ message_json = json_encoder.encode(messages_by_device["*"])
for row in txn:
# Add the message for all devices for this user on this
# server.
@@ -435,7 +457,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = json.dumps(messages_by_device[device])
+ message_json = json_encoder.encode(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
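The device-inbox getter above now returns a (messages, stream position) pair, and the deleter is awaited for a row count. A hedged sketch of the fetch-then-delete pattern those signatures imply; `store` stands in for a DeviceInboxWorkerStore-like object, and nothing here is taken from the patch beyond the method names and signatures:

from typing import List


async def drain_to_device_messages(
    store, user_id: str, device_id: str, since_token: int
) -> List[dict]:
    # get_to_device_stream_token() is synchronous and gives the upper bound;
    # the getter returns the messages plus how far into the stream it got.
    current_token = store.get_to_device_stream_token()
    messages, stream_pos = await store.get_new_messages_for_device(
        user_id, device_id, since_token, current_token, limit=100
    )
    # Delete up to the position actually reached; the coroutine returns the
    # number of rows removed (unused here).
    await store.delete_messages_for_device(user_id, device_id, stream_pos)
    return messages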
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/databases/main/devices.py
index fb9f798e29..add4e3ea0e 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,14 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
-from typing import List, Optional, Set, Tuple
-
-from six import iteritems
-
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -33,17 +28,13 @@ from synapse.logging.opentracing import (
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
- Database,
+ DatabasePool,
LoggingTransaction,
make_tuple_comparison_clause,
)
-from synapse.types import Collection, get_verify_key_from_cross_signing_key
-from synapse.util.caches.descriptors import (
- Cache,
- cached,
- cachedInlineCallbacks,
- cachedList,
-)
+from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
+from synapse.util import json_encoder
+from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -57,38 +48,36 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def get_device(self, user_id, device_id):
+ async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to retrieve
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to retrieve
Returns:
- defer.Deferred for a dict containing the device information
+ A dict containing the device information
Raises:
StoreError: if the device is not found
"""
- return self.db.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
- @defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
+ async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
Args:
- user_id (str):
+ user_id:
Returns:
- defer.Deferred: resolves to a dict from device_id to a dict
- containing "device_id", "user_id" and "display_name" for each
- device.
+ A mapping from device_id to a dict containing "device_id", "user_id"
+ and "display_name" for each device.
"""
- devices = yield self.db.simple_select_list(
+ devices = await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -98,21 +87,22 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
@trace
- @defer.inlineCallbacks
- def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ async def get_device_updates_by_remote(
+ self, destination: str, from_stream_id: int, limit: int
+ ) -> Tuple[int, List[Tuple[str, dict]]]:
"""Get a stream of device updates to send to the given remote server.
Args:
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- limit (int): Maximum number of device updates to return
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
+ limit: Maximum number of device updates to return
+
Returns:
- Deferred[tuple[int, list[tuple[string,dict]]]]:
- current stream id (ie, the stream id of the last update included in the
- response), and the list of updates, where each update is a pair of EDU
- type and EDU contents
+            A tuple containing the current stream id (ie, the stream id of the
+                last update included in the response) and the list of updates,
+                where each update is a pair of EDU type and EDU contents.
"""
- now_stream_id = self._device_list_id_gen.get_current_token()
+ now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -120,7 +110,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- updates = yield self.db.runInteraction(
+ updates = await self.db_pool.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
@@ -139,7 +129,7 @@ class DeviceWorkerStore(SQLBaseStore):
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
- cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
@@ -152,7 +142,7 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- cross_signing_key = yield self.get_e2e_cross_signing_key(
+ cross_signing_key = await self.get_e2e_cross_signing_key(
user, "self_signing"
)
if cross_signing_key:
@@ -203,12 +193,12 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- results = yield self._get_device_update_edus_by_remote(
+ results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
# add the updated cross-signing keys to the results list
- for user_id, result in iteritems(cross_signing_keys_by_user):
+ for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
results.append(("org.matrix.signing_key_update", result))
@@ -216,16 +206,21 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, results
def _get_device_updates_by_remote_txn(
- self, txn, destination, from_stream_id, now_stream_id, limit
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ from_stream_id: int,
+ now_stream_id: int,
+ limit: int,
):
"""Return device update information for a given remote destination
Args:
- txn (LoggingTransaction): The transaction to execute
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- now_stream_id (int): The maximum stream_id to filter updates by, inclusive
- limit (int): Maximum number of device updates to return
+ txn: The transaction to execute
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
+ now_stream_id: The maximum stream_id to filter updates by, inclusive
+ limit: Maximum number of device updates to return
Returns:
List: List of device updates
@@ -241,25 +236,26 @@ class DeviceWorkerStore(SQLBaseStore):
return list(txn)
- @defer.inlineCallbacks
- def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
+ async def _get_device_update_edus_by_remote(
+ self,
+ destination: str,
+ from_stream_id: int,
+ query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
+ ) -> List[Tuple[str, dict]]:
"""Returns a list of device update EDUs as well as E2EE keys
Args:
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
- user_id/device_id to update stream_id and the relevent json-encoded
+ user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
Returns:
- List[Dict]: List of objects representing an device update EDU
-
+            List of objects representing a device update EDU
"""
devices = (
- yield self.db.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
+ await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@@ -269,10 +265,10 @@ class DeviceWorkerStore(SQLBaseStore):
)
results = []
- for user_id, user_devices in iteritems(devices):
+ for user_id, user_devices in devices.items():
# The prev_id for the first row is always the last row before
# `from_stream_id`
- prev_id = yield self._get_last_device_update_for_remote_user(
+ prev_id = await self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id
)
@@ -295,17 +291,11 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
+ keys = device.keys
+ if keys:
+ result["keys"] = keys
- device_display_name = device.get("device_display_name", None)
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@@ -315,9 +305,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- def _get_last_device_update_for_remote_user(
- self, destination, user_id, from_stream_id
- ):
+ async def _get_last_device_update_for_remote_user(
+ self, destination: str, user_id: str, from_stream_id: int
+ ) -> int:
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
@@ -328,19 +318,25 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.db.runInteraction("get_last_device_update_for_remote_user", f)
+ return await self.db_pool.runInteraction(
+ "get_last_device_update_for_remote_user", f
+ )
- def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ async def mark_as_sent_devices_by_remote(
+ self, destination: str, stream_id: int
+ ) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
stream_id,
)
- def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ def _mark_as_sent_devices_by_remote_txn(
+ self, txn: LoggingTransaction, destination: str, stream_id: int
+ ) -> None:
# We update the device_lists_outbound_last_success with the successfully
# poked users.
sql = """
@@ -352,7 +348,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
- self.db.simple_upsert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
@@ -368,17 +364,21 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, stream_id))
- @defer.inlineCallbacks
- def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+ async def add_user_signature_change_to_streams(
+ self, from_user_id: str, user_ids: List[str]
+ ) -> int:
"""Persist that a user has made new signatures
Args:
- from_user_id (str): the user who made the signatures
- user_ids (list[str]): the users who were signed
+ from_user_id: the user who made the signatures
+ user_ids: the users who were signed
+
+ Returns:
+            The new stream ID.
"""
- with self._device_list_id_gen.get_next() as stream_id:
- yield self.db.runInteraction(
+ with await self._device_list_id_gen.get_next() as stream_id:
+ await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
from_user_id,
@@ -387,45 +387,54 @@ class DeviceWorkerStore(SQLBaseStore):
)
return stream_id
- def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+ def _add_user_signature_change_txn(
+ self,
+ txn: LoggingTransaction,
+ from_user_id: str,
+ user_ids: List[str],
+ stream_id: int,
+ ) -> None:
txn.call_after(
self._user_signature_stream_cache.entity_has_changed,
from_user_id,
stream_id,
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"user_signature_stream",
values={
"stream_id": stream_id,
"from_user_id": from_user_id,
- "user_ids": json.dumps(user_ids),
+ "user_ids": json_encoder.encode(user_ids),
},
)
- def get_device_stream_token(self):
- return self._device_list_id_gen.get_current_token()
+ @abc.abstractmethod
+ def get_device_stream_token(self) -> int:
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
@trace
- @defer.inlineCallbacks
- def get_user_devices_from_cache(self, query_list):
+ async def get_user_devices_from_cache(
+ self, query_list: List[Tuple[str, str]]
+ ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
- query_list(list): List of (user_id, device_ids), if device_ids is
+ query_list: List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
- (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
- a set of user_ids and results_map is a mapping of
- user_id -> device_id -> device_info
+ A tuple of (user_ids_not_in_cache, results_map), where
+ user_ids_not_in_cache is a set of user_ids and results_map is a
+ mapping of user_id -> device_id -> device_info.
"""
user_ids = {user_id for user_id, _ in query_list}
- user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
- users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+ users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
user_ids
)
user_ids_in_cache = {
@@ -439,19 +448,19 @@ class DeviceWorkerStore(SQLBaseStore):
continue
if device_id:
- device = yield self._get_cached_user_device(user_id, device_id)
+ device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
- results[user_id] = yield self.get_cached_devices_for_user(user_id)
+ results[user_id] = await self.get_cached_devices_for_user(user_id)
set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache)
return user_ids_not_in_cache, results
- @cachedInlineCallbacks(num_args=2, tree=True)
- def _get_cached_user_device(self, user_id, device_id):
- content = yield self.db.simple_select_one_onecol(
+ @cached(num_args=2, tree=True)
+ async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+ content = await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
@@ -459,9 +468,9 @@ class DeviceWorkerStore(SQLBaseStore):
)
return db_to_json(content)
- @cachedInlineCallbacks()
- def get_cached_devices_for_user(self, user_id):
- devices = yield self.db.simple_select_list(
+ @cached()
+ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+ devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
@@ -471,62 +480,18 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id):
- """Get all devices (with any device keys) for a user
-
- Returns:
- (stream_id, devices)
- """
- return self.db.runInteraction(
- "get_devices_with_keys_by_user",
- self._get_devices_with_keys_by_user_txn,
- user_id,
- )
-
- def _get_devices_with_keys_by_user_txn(self, txn, user_id):
- now_stream_id = self._device_list_id_gen.get_current_token()
-
- devices = self._get_e2e_device_keys_txn(
- txn, [(user_id, None)], include_all_devices=True
- )
-
- if devices:
- user_devices = devices[user_id]
- results = []
- for device_id, device in iteritems(user_devices):
- result = {"device_id": device_id}
-
- key_json = device.get("key_json", None)
- if key_json:
- result["keys"] = db_to_json(key_json)
-
- if "signatures" in device:
- for sig_user_id, sigs in device["signatures"].items():
- result["keys"].setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
-
- device_display_name = device.get("device_display_name", None)
- if device_display_name:
- result["device_display_name"] = device_display_name
-
- results.append(result)
-
- return now_stream_id, results
-
- return now_stream_id, []
-
- def get_users_whose_devices_changed(self, from_key, user_ids):
+ async def get_users_whose_devices_changed(
+ self, from_key: str, user_ids: Iterable[str]
+ ) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
- from_key (str): The device lists stream token
- user_ids (Iterable[str])
+ from_key: The device lists stream token
+ user_ids: The user IDs to query for devices.
Returns:
- Deferred[set[str]]: The set of user_ids whose devices have changed
- since `from_key`
+ The set of user_ids whose devices have changed since `from_key`
"""
from_key = int(from_key)
@@ -537,7 +502,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
if not to_check:
- return defer.succeed(set())
+ return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
@@ -557,18 +522,22 @@ class DeviceWorkerStore(SQLBaseStore):
return changes
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
- @defer.inlineCallbacks
- def get_users_whose_signatures_changed(self, user_id, from_key):
+ async def get_users_whose_signatures_changed(
+ self, user_id: str, from_key: str
+ ) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
Args:
- user_id (str): the user who made the signatures
- from_key (str): The device lists stream token
+ user_id: the user who made the signatures
+ from_key: The device lists stream token
+
+ Returns:
+ A set of user IDs with updated signatures.
"""
from_key = int(from_key)
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
@@ -576,48 +545,76 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ?
"""
- rows = yield self.db.execute(
+ rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
- return {user for row in rows for user in json.loads(row[0])}
+ return {user for row in rows for user in db_to_json(row[0])}
else:
return set()
async def get_all_device_list_changes_for_remotes(
- self, from_key: int, to_key: int, limit: int,
- ) -> List[Tuple[int, str]]:
- """Return a list of `(stream_id, entity)` which is the combined list of
- changes to devices and which destinations need to be poked. Entity is
- either a user ID (starting with '@') or a remote destination.
- """
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for device lists replication stream.
- # This query Does The Right Thing where it'll correctly apply the
- # bounds to the inner queries.
- sql = """
- SELECT stream_id, entity FROM (
- SELECT stream_id, user_id AS entity FROM device_lists_stream
- UNION ALL
- SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
- ) AS e
- WHERE ? < stream_id AND stream_id <= ?
- LIMIT ?
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- return await self.db.execute(
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_device_list_changes_for_remotes(txn):
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
+ sql = """
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
"get_all_device_list_changes_for_remotes",
- None,
- sql,
- from_key,
- to_key,
- limit,
+ _get_all_device_list_changes_for_remotes,
)
@cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id):
+ async def get_device_list_last_stream_id_for_remote(
+ self, user_id: str
+ ) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -628,10 +625,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
- inlineCallbacks=True,
)
- def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self.db.simple_select_many_batch(
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -644,8 +640,7 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- @defer.inlineCallbacks
- def get_user_ids_requiring_device_list_resync(
+ async def get_user_ids_requiring_device_list_resync(
self, user_ids: Optional[Collection[str]] = None,
) -> Set[str]:
"""Given a list of remote users return the list of users that we
@@ -656,7 +651,7 @@ class DeviceWorkerStore(SQLBaseStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
@@ -664,7 +659,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
)
else:
- rows = yield self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
@@ -673,11 +668,11 @@ class DeviceWorkerStore(SQLBaseStore):
return {row["user_id"] for row in rows}
- def mark_remote_user_device_cache_as_stale(self, user_id: str):
+ async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
@@ -685,12 +680,12 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
- def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+ async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -699,17 +694,17 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
table="device_lists_stream",
@@ -717,7 +712,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
# create a unique index on device_lists_remote_cache
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_lists_remote_cache_unique_idx",
index_name="device_lists_remote_cache_unique_id",
table="device_lists_remote_cache",
@@ -726,7 +721,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
# And one on device_lists_remote_extremeties
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_lists_remote_extremeties_unique_idx",
index_name="device_lists_remote_extremeties_unique_idx",
table="device_lists_remote_extremeties",
@@ -735,35 +730,34 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
# once they complete, we can remove the old non-unique indexes.
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
self._drop_device_list_streams_non_unique_indexes,
)
# clear out duplicate device list outbound pokes
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
)
# a pair of background updates that were added during the 1.14 release cycle,
# but replaced with 58/06dlols_unique_idx.py
- self.db.updates.register_noop_background_update(
+ self.db_pool.updates.register_noop_background_update(
"device_lists_outbound_last_success_unique_idx",
)
- self.db.updates.register_noop_background_update(
+ self.db_pool.updates.register_noop_background_update(
"drop_device_lists_outbound_last_success_non_unique_idx",
)
- @defer.inlineCallbacks
- def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
- yield self.db.runWithConnection(f)
- yield self.db.updates._end_background_update(
+ await self.db_pool.runWithConnection(f)
+ await self.db_pool.updates._end_background_update(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
)
return 1
@@ -783,7 +777,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn):
clause, args = make_tuple_comparison_clause(
- self.db.engine, [(x, last_row[x]) for x in KEY_COLS]
+ self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
)
sql = """
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
@@ -799,30 +793,32 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
",".join(KEY_COLS), # ORDER BY
)
txn.execute(sql, args + [batch_size])
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
row = None
for row in rows:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
)
row["sent"] = False
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn, "device_lists_outbound_pokes", row,
)
if row:
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
)
return len(rows)
- rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn)
+ rows = await self.db_pool.runInteraction(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
+ )
if not rows:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
)
@@ -830,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
@@ -841,18 +837,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
- @defer.inlineCallbacks
- def store_device(self, user_id, device_id, initial_device_display_name):
+ async def store_device(
+ self, user_id: str, device_id: str, initial_device_display_name: str
+ ) -> bool:
"""Ensure the given device is known; add it to the store if not
Args:
- user_id (str): id of user associated with the device
- device_id (str): id of device
- initial_device_display_name (str): initial displayname of the
- device. Ignored if device exists.
+ user_id: id of user associated with the device
+ device_id: id of device
+ initial_device_display_name: initial displayname of the device.
+ Ignored if device exists.
+
Returns:
- defer.Deferred: boolean whether the device was inserted or an
- existing device existed with that ID.
+ Whether the device was inserted or an existing device existed with that ID.
+
Raises:
StoreError: if the device is already in use
"""
@@ -861,7 +859,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
- inserted = yield self.db.simple_insert(
+ inserted = await self.db_pool.simple_insert(
"devices",
values={
"user_id": user_id,
@@ -875,7 +873,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
- hidden = yield self.db.simple_select_one_onecol(
+ hidden = await self.db_pool.simple_select_one_onecol(
"devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden",
@@ -900,17 +898,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
raise StoreError(500, "Problem storing device.")
- @defer.inlineCallbacks
- def delete_device(self, user_id, device_id):
+ async def delete_device(self, user_id: str, device_id: str) -> None:
"""Delete a device.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to delete
- Returns:
- defer.Deferred
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to delete
"""
- yield self.db.simple_delete_one(
+ await self.db_pool.simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
@@ -918,17 +913,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache.invalidate((user_id, device_id))
- @defer.inlineCallbacks
- def delete_devices(self, user_id, device_ids):
+ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
Args:
- user_id (str): The ID of the user which owns the devices
- device_ids (list): The IDs of the devices to delete
- Returns:
- defer.Deferred
+ user_id: The ID of the user which owns the devices
+ device_ids: The IDs of the devices to delete
"""
- yield self.db.simple_delete_many(
+ await self.db_pool.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
@@ -938,50 +930,46 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
- def update_device(self, user_id, device_id, new_display_name=None):
+ async def update_device(
+ self, user_id: str, device_id: str, new_display_name: Optional[str] = None
+ ) -> None:
"""Update a device. Only updates the device if it is not marked as
hidden.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to update
- new_display_name (str|None): new displayname for device; None
- to leave unchanged
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to update
+ new_display_name: new displayname for device; None to leave unchanged
Raises:
StoreError: if the device is not found
- Returns:
- defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
- return defer.succeed(None)
- return self.db.simple_update_one(
+ return None
+ await self.db_pool.simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
desc="update_device",
)
- def update_remote_device_list_cache_entry(
- self, user_id, device_id, content, stream_id
- ):
+ async def update_remote_device_list_cache_entry(
+ self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+ ) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
device list.
Args:
- user_id (str): User to update device list for
- device_id (str): ID of decivice being updated
- content (dict): new data on this device
- stream_id (int): the version of the device list
-
- Returns:
- Deferred[None]
+ user_id: User to update device list for
+            device_id: ID of device being updated
+ content: new data on this device
+ stream_id: the version of the device list
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -991,10 +979,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _update_remote_device_list_cache_entry_txn(
- self, txn, user_id, device_id, content, stream_id
- ):
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ content: JsonDict,
+ stream_id: int,
+ ) -> None:
if content.get("deleted"):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -1002,11 +995,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
- values={"content": json.dumps(content)},
+ values={"content": json_encoder.encode(content)},
# we don't need to lock, because we assume we are the only thread
# updating this user's devices.
lock=False,
@@ -1018,7 +1011,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -1028,21 +1021,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
- def update_remote_device_list_cache(self, user_id, devices, stream_id):
+ async def update_remote_device_list_cache(
+ self, user_id: str, devices: List[dict], stream_id: int
+ ) -> None:
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
device list.
Args:
- user_id (str): User to update device list for
- devices (list[dict]): list of device objects supplied over federation
- stream_id (int): the version of the device list
-
- Returns:
- Deferred[None]
+ user_id: User to update device list for
+ devices: list of device objects supplied over federation
+ stream_id: the version of the device list
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -1050,19 +1042,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id,
)
- def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
- self.db.simple_delete_txn(
+ def _update_remote_device_list_cache_txn(
+ self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
+ ) -> None:
+ self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
}
for content in devices
],
@@ -1074,7 +1068,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -1087,20 +1081,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# If we're replacing the remote user's device list cache presumably
# we've done a full resync, so we remove the entry that says we need
# to resync
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
)
- @defer.inlineCallbacks
- def add_device_change_to_streams(self, user_id, device_ids, hosts):
+ async def add_device_change_to_streams(
+ self, user_id: str, device_ids: Collection[str], hosts: List[str]
+ ):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
if not device_ids:
return
- with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
- yield self.db.runInteraction(
+ with await self._device_list_id_gen.get_next_mult(
+ len(device_ids)
+ ) as stream_ids:
+ await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
user_id,
@@ -1112,10 +1109,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with self._device_list_id_gen.get_next_mult(
+ with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
user_id,
@@ -1150,7 +1147,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[(user_id, device_id, min_stream_id) for device_id in device_ids],
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@@ -1160,7 +1157,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _add_device_outbound_poke_to_stream_txn(
- self, txn, user_id, device_ids, hosts, stream_ids, context,
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: List[str],
+        stream_ids: List[int],
+ context: Dict[str, str],
):
for host in hosts:
txn.call_after(
@@ -1172,7 +1175,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
now = self._clock.time_msec()
next_stream_id = iter(stream_ids)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
@@ -1183,7 +1186,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id": device_id,
"sent": False,
"ts": now,
- "opentracing_context": json.dumps(context)
+ "opentracing_context": json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
}
@@ -1192,7 +1195,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
- def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
+ def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers.
@@ -1279,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return run_as_background_process(
"prune_old_outbound_device_pokes",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_prune_old_outbound_device_pokes",
_prune_txn,
)
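Several methods above were reshaped into the replication-stream contract of (updates, upto_token, limited). A minimal paging sketch under that contract; `store` stands in for a DeviceWorkerStore, and the instance name "master" is an illustrative value, not taken from the patch:

from typing import List, Tuple


async def fetch_device_list_changes(
    store, last_id: int, current_id: int, batch_size: int = 100
) -> List[Tuple[int, tuple]]:
    rows: List[Tuple[int, tuple]] = []
    while True:
        updates, upto_token, limited = await store.get_all_device_list_changes_for_remotes(
            "master", last_id, current_id, batch_size
        )
        rows.extend(updates)
        if not limited:
            # Everything up to current_id has been returned.
            return rows
        # The returned token becomes the exclusive lower bound for the next call.
        last_id = upto_token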
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/databases/main/directory.py
index e1d1bc3e05..e5060d4c46 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,30 +14,29 @@
# limitations under the License.
from collections import namedtuple
-from typing import Optional
-
-from twisted.internet import defer
+from typing import Iterable, List, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
+from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
class DirectoryWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias):
- """ Get's the room_id and server list for a given room_alias
+ async def get_association_from_room_alias(
+ self, room_alias: RoomAlias
+ ) -> Optional[RoomAliasMapping]:
+ """Gets the room_id and server list for a given room_alias
Args:
- room_alias (RoomAlias)
+ room_alias: The alias to translate to an ID.
Returns:
- Deferred: results in namedtuple with keys "room_id" and
- "servers" or None if no association can be found
+ The room alias mapping or None if no association can be found.
"""
- room_id = yield self.db.simple_select_one_onecol(
+ room_id = await self.db_pool.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
- servers = yield self.db.simple_select_onecol(
+ servers = await self.db_pool.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -60,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
- def get_room_alias_creator(self, room_alias):
- return self.db.simple_select_one_onecol(
+ async def get_room_alias_creator(self, room_alias: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -69,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore):
)
@cached(max_entries=5000)
- def get_aliases_for_room(self, room_id):
- return self.db.simple_select_onecol(
+ async def get_aliases_for_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@@ -79,22 +78,24 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
- @defer.inlineCallbacks
- def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
+ async def create_room_alias_association(
+ self,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Iterable[str],
+ creator: Optional[str] = None,
+ ) -> None:
""" Creates an association between a room alias and room_id/servers
Args:
- room_alias (RoomAlias)
- room_id (str)
- servers (list)
- creator (str): Optional user_id of creator.
-
- Returns:
- Deferred
+ room_alias: The alias to create.
+ room_id: The target of the alias.
+ servers: A list of servers through which it may be possible to join the room.
+ creator: Optional user_id of creator.
"""
def alias_txn(txn):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"room_aliases",
{
@@ -104,7 +105,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
@@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- return ret
- @defer.inlineCallbacks
- def delete_room_alias(self, room_alias):
- room_id = yield self.db.runInteraction(
+ async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+ room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
return room_id
- def _delete_room_alias_txn(self, txn, room_alias):
+ def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
@@ -160,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
- def update_aliases_for_room(
+ async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
- ):
+ ) -> None:
"""Repoint all of the aliases for a given room, to a different room.
Args:
@@ -190,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
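A rough usage sketch for the now-async `create_room_alias_association`, showing how a caller might treat the 409 it raises when the alias row already exists. The `store` argument is assumed to be a `DirectoryStore`; this helper is illustrative and not part of the change set.

from synapse.api.errors import SynapseError


async def ensure_alias(store, room_alias, room_id, servers, creator=None) -> None:
    # Create the alias, but tolerate a duplicate that already points at the
    # same room; any other conflict (or error) is re-raised.
    try:
        await store.create_room_alias_association(room_alias, room_id, servers, creator)
    except SynapseError as e:
        if e.code != 409:
            raise
        existing = await store.get_association_from_room_alias(room_alias)
        if existing is None or existing.room_id != room_id:
            raise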
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 23f4570c4b..12cecceec2 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,18 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-
-from twisted.internet import defer
+from typing import Optional
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.util import json_encoder
class EndToEndRoomKeyStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ async def update_e2e_room_key(
+ self, user_id, version, room_id, session_id, room_key
+ ):
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
@@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@@ -50,13 +50,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
- "session_data": json.dumps(room_key["session_data"]),
+ "session_data": json_encoder.encode(room_key["session_data"]),
},
desc="update_e2e_room_key",
)
- @defer.inlineCallbacks
- def add_e2e_room_keys(self, user_id, version, room_keys):
+ async def add_e2e_room_keys(self, user_id, version, room_keys):
"""Bulk add room keys to a given backup.
Args:
@@ -77,7 +76,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
- "session_data": json.dumps(room_key["session_data"]),
+ "session_data": json_encoder.encode(room_key["session_data"]),
}
)
log_kv(
@@ -89,13 +88,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
}
)
- yield self.db.simple_insert_many(
+ await self.db_pool.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@trace
- @defer.inlineCallbacks
- def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
@@ -110,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -125,7 +123,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = yield self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -148,12 +146,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"forwarded_count": row["forwarded_count"],
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
- "session_data": json.loads(row["session_data"]),
+ "session_data": db_to_json(row["session_data"]),
}
return sessions
- def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -168,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
that we want to query
Returns:
- Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+ dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
@@ -222,20 +220,20 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": row[2],
"forwarded_count": row[3],
"is_verified": row[4],
- "session_data": json.loads(row[5]),
+ "session_data": db_to_json(row[5]),
}
return ret
- def count_e2e_room_keys(self, user_id, version):
+ async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
- user_id (str): the user whose backup we're querying
- version (str): the version ID of the backup we're querying about
+ user_id: the user whose backup we're querying
+ version: the version ID of the backup we're querying about
"""
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
@@ -243,8 +241,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- @defer.inlineCallbacks
- def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_e2e_room_keys(
+ self, user_id, version, room_id=None, session_id=None
+ ):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
@@ -259,7 +258,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred of the deletion transaction
+ The deletion transaction
"""
keyvalues = {"user_id": user_id, "version": int(version)}
@@ -268,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- yield self.db.simple_delete(
+ await self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@@ -284,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No current backup version")
return row[0]
- def get_e2e_room_keys_version_info(self, user_id, version=None):
+ async def get_e2e_room_keys_version_info(self, user_id, version=None):
"""Get info metadata about a version of our room_keys backup.
Args:
@@ -294,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
Returns:
- A deferred dict giving the info metadata for this backup version, with
+ A dict giving the info metadata for this backup version, with
fields including:
version(str)
algorithm(str)
@@ -313,24 +312,24 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
)
- result["auth_data"] = json.loads(result["auth_data"])
+ result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@trace
- def create_e2e_room_keys_version(self, user_id, info):
+ async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
@@ -339,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
info(dict): the info about the backup version to be created
Returns:
- A deferred string for the newly created version ID
+ The newly created version ID
"""
def _create_e2e_room_keys_version_txn(txn):
@@ -353,46 +352,50 @@ class EndToEndRoomKeyStore(SQLBaseStore):
new_version = str(int(current_version) + 1)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
"user_id": user_id,
"version": new_version,
"algorithm": info["algorithm"],
- "auth_data": json.dumps(info["auth_data"]),
+ "auth_data": json_encoder.encode(info["auth_data"]),
},
)
return new_version
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@trace
- def update_e2e_room_keys_version(
- self, user_id, version, info=None, version_etag=None
- ):
+ async def update_e2e_room_keys_version(
+ self,
+ user_id: str,
+ version: str,
+ info: Optional[dict] = None,
+ version_etag: Optional[int] = None,
+ ) -> None:
"""Update a given backup version
Args:
- user_id(str): the user whose backup version we're updating
- version(str): the version ID of the backup version we're updating
- info (dict): the new backup version info to store. If None, then
- the backup version info is not updated
- version_etag (Optional[int]): etag of the keys in the backup. If
- None, then the etag is not updated
+ user_id: the user whose backup version we're updating
+ version: the version ID of the backup version we're updating
+ info: the new backup version info to store. If None, then the backup
+ version info is not updated.
+ version_etag: etag of the keys in the backup. If None, then the etag
+ is not updated.
"""
updatevalues = {}
if info is not None and "auth_data" in info:
- updatevalues["auth_data"] = json.dumps(info["auth_data"])
+ updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
if version_etag is not None:
updatevalues["etag"] = version_etag
if updatevalues:
- return self.db.simple_update(
+ await self.db_pool.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
@@ -400,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- def delete_e2e_room_keys_version(self, user_id, version=None):
+ async def delete_e2e_room_keys_version(
+ self, user_id: str, version: Optional[str] = None
+ ) -> None:
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
Args:
- user_id(str): the user whose backup version we're deleting
- version(str): Optional. the version ID of the backup version we're deleting
+ user_id: the user whose backup version we're deleting
+ version: Optional. the version ID of the backup version we're deleting
If missing, we delete the current backup version info.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present,
@@ -421,19 +426,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
else:
this_version = version
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version},
)
- return self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)
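The e2e_room_keys changes swap raw `json.dumps`/`json.loads` for the shared `json_encoder`/`db_to_json` helpers. A minimal sketch of the round-trip, using the stdlib `json` module as a stand-in for those helpers (the real ones additionally use a compact encoder and handle frozen dicts):

import json


def encode_for_db(value) -> str:
    # Stand-in for synapse.util.json_encoder.encode
    return json.dumps(value, separators=(",", ":"))


def decode_from_db(db_content: str):
    # Stand-in for synapse.storage._base.db_to_json
    return json.loads(db_content)


session_data = {"ciphertext": "...", "ephemeral": "...", "mac": "..."}
stored = encode_for_db(session_data)  # value written to e2e_room_keys.session_data
assert decode_from_db(stored) == session_data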
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 20698bfd16..fba3098ea2 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,36 +14,79 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List
+import abc
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
-from six import iteritems
-
-from canonicaljson import encode_canonical_json, json
+import attr
+from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
-from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.handlers.e2e_keys import SignatureListItem
+
+
+@attr.s
+class DeviceKeyLookupResult:
+ """The type returned by get_e2e_device_keys_and_signatures"""
+
+ display_name = attr.ib(type=Optional[str])
+
+ # the key data from e2e_device_keys_json. Typically includes fields like
+ # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
+ # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
+ keys = attr.ib(type=Optional[JsonDict])
+
class EndToEndKeyWorkerStore(SQLBaseStore):
+ async def get_e2e_device_keys_for_federation_query(
+ self, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
+ """Get all devices (with any device keys) for a user
+
+ Returns:
+ (stream_id, devices)
+ """
+ now_stream_id = self.get_device_stream_token()
+
+ devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
+
+ if devices:
+ user_devices = devices[user_id]
+ results = []
+ for device_id, device in user_devices.items():
+ result = {"device_id": device_id}
+
+ keys = device.keys
+ if keys:
+ result["keys"] = keys
+
+ device_display_name = device.display_name
+ if device_display_name:
+ result["device_display_name"] = device_display_name
+
+ results.append(result)
+
+ return now_stream_id, results
+
+ return now_stream_id, []
+
@trace
- @defer.inlineCallbacks
- def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
+ async def get_e2e_device_keys_for_cs_api(
+ self, query_list: List[Tuple[str, Optional[str]]]
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -53,45 +96,103 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = yield self.db.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
- )
+ results = await self.get_e2e_device_keys_and_signatures(query_list)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
rv = {}
- for user_id, device_keys in iteritems(results):
+ for user_id, device_keys in results.items():
rv[user_id] = {}
- for device_id, device_info in iteritems(device_keys):
- r = db_to_json(device_info.pop("key_json"))
+ for device_id, device_info in device_keys.items():
+ r = device_info.keys
r["unsigned"] = {}
- display_name = device_info["device_display_name"]
+ display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
- if "signatures" in device_info:
- for sig_user_id, sigs in device_info["signatures"].items():
- r.setdefault("signatures", {}).setdefault(
- sig_user_id, {}
- ).update(sigs)
rv[user_id][device_id] = r
return rv
@trace
- def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ):
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: List[Tuple[str, Optional[str]]],
+ include_all_devices: bool = False,
+ include_deleted_devices: bool = False,
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ """Fetch a list of device keys
+
+ Any cross-signatures made on the keys by the owner of the device are also
+ included.
+
+ The cross-signatures are added to the `signatures` field within the `keys`
+ object in the response.
+
+ Args:
+ query_list: List of pairs of user_ids and device_ids. Device id can be None
+ to indicate "all devices for this user"
+
+ include_all_devices: whether to return devices without device keys
+
+ include_deleted_devices: whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
+
+ Returns:
+ Dict mapping from user-id to dict mapping from device_id to
+ key data.
+ """
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
+ result = await self.db_pool.runInteraction(
+ "get_e2e_device_keys",
+ self._get_e2e_device_keys_txn,
+ query_list,
+ include_all_devices,
+ include_deleted_devices,
+ )
+
+ # get the (user_id, device_id) tuples to look up cross-signatures for
+ signature_query = (
+ (user_id, device_id)
+ for user_id, dev in result.items()
+ for device_id, d in dev.items()
+ if d is not None and d.keys is not None
+ )
+
+ for batch in batch_iter(signature_query, 50):
+ cross_sigs_result = await self.db_pool.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_for_devices_txn,
+ batch,
+ )
+
+ # add each cross-signing signature to the correct device in the result dict.
+ for (user_id, key_id, device_id, signature) in cross_sigs_result:
+ target_device_result = result[user_id][device_id]
+ target_device_signatures = target_device_result.keys.setdefault(
+ "signatures", {}
+ )
+ signing_user_signatures = target_device_signatures.setdefault(
+ user_id, {}
+ )
+ signing_user_signatures[key_id] = signature
+
+ log_kv(result)
+ return result
+
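+ `get_e2e_device_keys_and_signatures` above merges the cross-signing signature rows into each device's `keys["signatures"]` map after the initial lookup. The sketch below reproduces just that merge on plain data; the attrs class is trimmed down and the signature rows are fabricated for illustration.
+
+ from typing import Dict, Optional
+
+ import attr
+
+
+ @attr.s
+ class DeviceKeyLookupResult:
+     display_name = attr.ib(type=Optional[str])
+     keys = attr.ib(type=Optional[dict])
+
+
+ def merge_cross_signatures(result, cross_sigs_rows) -> None:
+     # Each row is (user_id, key_id, target_device_id, signature); attach it under
+     # keys["signatures"][user_id][key_id] of the matching device entry.
+     for user_id, key_id, device_id, signature in cross_sigs_rows:
+         device = result[user_id][device_id]
+         signatures = device.keys.setdefault("signatures", {})
+         signatures.setdefault(user_id, {})[key_id] = signature
+
+
+ result = {"@alice:test": {"DEV1": DeviceKeyLookupResult("phone", {"algorithms": []})}}
+ merge_cross_signatures(result, [("@alice:test", "ed25519:abc", "DEV1", "sig_base64")])
+ print(result["@alice:test"]["DEV1"].keys["signatures"])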
+ def _get_e2e_device_keys_txn(
+ self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ """Get information on devices from the database
+
+ The results include the device's keys and self-signatures, but *not* any
+ cross-signing signatures which have been added subsequently (for which, see
+ get_e2e_device_keys_and_signatures)
+ """
query_clauses = []
query_params = []
- signature_query_clauses = []
- signature_query_params = []
if include_all_devices is False:
include_deleted_devices = False
@@ -102,24 +203,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
- signature_query_clause = "target_user_id = ?"
- signature_query_params.append(user_id)
if device_id is not None:
query_clause += " AND device_id = ?"
query_params.append(device_id)
- signature_query_clause += " AND target_device_id = ?"
- signature_query_params.append(device_id)
-
- signature_query_clause += " AND user_id = ?"
- signature_query_params.append(user_id)
query_clauses.append(query_clause)
- signature_query_clauses.append(signature_query_clause)
sql = (
"SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
+ " d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -130,54 +223,53 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.db.cursor_to_dict(txn)
- result = {}
- for row in rows:
+ result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
+ for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
+ deleted_devices.remove((user_id, device_id))
+ result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
+ display_name, db_to_json(key_json) if key_json else None
+ )
if include_deleted_devices:
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
- # get signatures on the device
- signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
- " OR ".join("(" + q + ")" for q in signature_query_clauses)
- )
+ return result
- txn.execute(signature_sql, signature_query_params)
- rows = self.db.cursor_to_dict(txn)
-
- # add each cross-signing signature to the correct device in the result dict.
- for row in rows:
- signing_user_id = row["user_id"]
- signing_key_id = row["key_id"]
- target_user_id = row["target_user_id"]
- target_device_id = row["target_device_id"]
- signature = row["signature"]
-
- target_user_result = result.get(target_user_id)
- if not target_user_result:
- continue
+ def _get_e2e_cross_signing_signatures_for_devices_txn(
+ self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+ ) -> List[Tuple[str, str, str, str]]:
+ """Get cross-signing signatures for a given list of devices
- target_device_result = target_user_result.get(target_device_id)
- if not target_device_result:
- # note that target_device_result will be None for deleted devices.
- continue
+ Returns signatures made by the owners of the devices.
- target_device_signatures = target_device_result.setdefault("signatures", {})
- signing_user_signatures = target_device_signatures.setdefault(
- signing_user_id, {}
+ Returns: a list of results; each entry in the list is a tuple of
+ (user_id, key_id, target_device_id, signature).
+ """
+ signature_query_clauses = []
+ signature_query_params = []
+
+ for (user_id, device_id) in device_query:
+ signature_query_clauses.append(
+ "target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
- signing_user_signatures[signing_key_id] = signature
+ signature_query_params.extend([user_id, device_id, user_id])
- log_kv(result)
- return result
+ signature_sql = """
+ SELECT user_id, key_id, target_device_id, signature
+ FROM e2e_cross_signing_signatures WHERE %s
+ """ % (
+ " OR ".join("(" + q + ")" for q in signature_query_clauses)
+ )
- @defer.inlineCallbacks
- def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ txn.execute(signature_sql, signature_query_params)
+ return txn.fetchall()
+
+ async def get_e2e_one_time_keys(
+ self, user_id: str, device_id: str, key_ids: List[str]
+ ) -> Dict[Tuple[str, str], str]:
"""Retrieve a number of one-time keys for a user
Args:
@@ -187,11 +279,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
retrieve
Returns:
- deferred resolving to Dict[(str, str), str]: map from (algorithm,
- key_id) to json string for key
+ A map from (algorithm, key_id) to json string for key
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
@@ -203,17 +294,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
- @defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ async def add_e2e_one_time_keys(
+ self,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
"""Insert some new one time keys for a device. Errors if any of the
keys already exist.
Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- time_now(long): insertion time to record (ms since epoch)
- new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
- (algorithm, key_id, key json)
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn):
@@ -224,7 +319,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
values=[
@@ -243,15 +338,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@cached(max_entries=10000)
- def count_e2e_one_time_keys(self, user_id, device_id):
+ async def count_e2e_one_time_keys(
+ self, user_id: str, device_id: str
+ ) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
- Dict mapping from algorithm to number of keys for that algorithm.
+ A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@@ -266,26 +363,27 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- @defer.inlineCallbacks
- def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+ async def get_e2e_cross_signing_key(
+ self, user_id: str, key_type: str, from_user_id: Optional[str] = None
+ ) -> Optional[dict]:
"""Returns a user's cross-signing key.
Args:
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
+ user_id: the user whose key is being requested
+ key_type: the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
+ from_user_id: if specified, signatures made by this user on
the self-signing key will be included in the result
Returns:
dict of the key data or None if not found
"""
- res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
user_keys = res.get(user_id)
if not user_keys:
return None
@@ -303,7 +401,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
- def _get_bare_e2e_cross_signing_keys_bulk(
+ async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@@ -311,16 +409,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
- user_ids (list[str]): the users whose keys are being requested
+ user_ids: the users whose keys are being requested
Returns:
- dict[str, dict[str, dict]]: mapping from user ID to key type to key
- data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A mapping from user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict, or
+ their user ID will map to None.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@@ -363,12 +460,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, params)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
- key = json.loads(row["keydata"])
+ key = db_to_json(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
@@ -422,7 +519,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
query_params.extend(item)
txn.execute(sql, query_params)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
for row in rows:
@@ -451,28 +548,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return keys
- @defer.inlineCallbacks
- def get_e2e_cross_signing_keys_bulk(
- self, user_ids: List[str], from_user_id: str = None
- ) -> defer.Deferred:
+ async def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: Optional[str] = None
+ ) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users.
Args:
- user_ids (list[str]): the users whose keys are being requested
- from_user_id (str): if specified, signatures made by this user on
+ user_ids: the users whose keys are being requested
+ from_user_id: if specified, signatures made by this user on
the self-signing keys will be included in the result
Returns:
- Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
- key data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A map of user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict,
+ or their user ID will map to None.
"""
- result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+ result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
@@ -481,39 +576,73 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
- def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
- """Return a list of changes from the user signature stream to notify remotes.
+ async def get_all_user_signature_changes_for_remotes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for the user signature replication stream.
+
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers.
Args:
- from_key (int): the stream ID to start at (exclusive)
- to_key (int): the stream ID to end at (inclusive)
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
Returns:
- Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
- """
- sql = """
- SELECT stream_id, from_user_id AS user_id
- FROM user_signature_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
- return self.db.execute(
+
+ if last_id == current_id:
+ return [], current_id, False
+
+ def _get_all_user_signature_changes_for_remotes_txn(txn):
+ sql = """
+ SELECT stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
+
+ updates = [(row[0], (row[1:])) for row in txn]
+
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
"get_all_user_signature_changes_for_remotes",
- None,
- sql,
- from_key,
- to_key,
- limit,
+ _get_all_user_signature_changes_for_remotes_txn,
)
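The rewritten `get_all_user_signature_changes_for_remotes` follows the common replication-stream shape: return the updates, the token to resume from, and whether the limit was hit. A standalone sketch of just that token/limit bookkeeping (the rows here stand in for the SQL result and are fabricated):

from typing import List, Tuple


def paginate_updates(
    rows: List[Tuple[int, str]], last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    # `rows` plays the role of the query result: (stream_id, user_id) pairs with
    # last_id < stream_id <= current_id, ordered by stream_id, at most `limit` long.
    if last_id == current_id:
        return [], current_id, False

    updates = [(stream_id, (user_id,)) for stream_id, user_id in rows[:limit]]

    limited = False
    upto_token = current_id
    if len(updates) >= limit:
        # There may be more rows; callers should resume from the last one returned.
        upto_token = updates[-1][0]
        limited = True

    return updates, upto_token, limited


print(paginate_updates([(5, "@a:test"), (6, "@b:test")], 4, 10, 2))  # limited fetch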
+ @abc.abstractmethod
+ def get_device_stream_token(self) -> int:
+ """Get the current stream id from the _device_list_id_gen"""
+ ...
+
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ async def set_e2e_device_keys(
+ self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+ ) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@@ -524,7 +653,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
- old_key_json = self.db.simple_select_one_onecol_txn(
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -540,7 +669,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"Message": "Device key already stored."})
return False
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -549,10 +678,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
- return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+ return await self.db_pool.runInteraction(
+ "set_e2e_device_keys", _set_e2e_device_keys_txn
+ )
- def claim_e2e_one_time_keys(self, query_list):
- """Take a list of one time keys out of the database"""
+ async def claim_e2e_one_time_keys(
+ self, query_list: Iterable[Tuple[str, str, str]]
+ ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+ """Take a list of one time keys out of the database.
+
+ Args:
+ query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A map of user ID -> a map of device ID -> a map of key ID -> JSON bytes.
+ """
@trace
def _claim_e2e_one_time_keys(txn):
@@ -588,11 +728,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
- def delete_e2e_keys_by_device(self, user_id, device_id):
+ async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@@ -601,12 +741,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"user_id": user_id,
}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -615,11 +755,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
- def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key.
Args:
@@ -629,6 +769,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
+ stream_id (int)
"""
# the 'key' dict will look something like:
# {
@@ -654,7 +795,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# We only need to do this for local users, since remote servers should be
# responsible for checking this for their own users.
if self.hs.is_mine_id(user_id):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"devices",
values={
@@ -666,23 +807,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
# and finally, store the key itself
- with self._cross_signing_id_gen.get_next() as stream_id:
- self.db.simple_insert_txn(
- txn,
- "e2e_cross_signing_keys",
- values={
- "user_id": user_id,
- "keytype": key_type,
- "keydata": json.dumps(key),
- "stream_id": stream_id,
- },
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json_encoder.encode(key),
+ "stream_id": stream_id,
+ },
+ )
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
- def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
@@ -690,22 +830,27 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.db.runInteraction(
- "add_e2e_cross_signing_key",
- self._set_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- key,
- )
- def store_e2e_cross_signing_signatures(self, user_id, signatures):
+ with await self._cross_signing_id_gen.get_next() as stream_id:
+ return await self.db_pool.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ stream_id,
+ )
+
+ async def store_e2e_cross_signing_signatures(
+ self, user_id: str, signatures: "Iterable[SignatureListItem]"
+ ) -> None:
"""Stores cross-signing signatures.
Args:
- user_id (str): the user who made the signatures
- signatures (iterable[SignatureListItem]): signatures to add
+ user_id: the user who made the signatures
+ signatures: signatures to add
"""
- return self.db.simple_insert_many(
+ await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 24ce8c4330..0b69aa6a94 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,18 +14,17 @@
# limitations under the License.
import itertools
import logging
-from typing import Dict, List, Optional, Set, Tuple
-
-from six.moves.queue import Empty, PriorityQueue
-
-from twisted.internet import defer
+from queue import Empty, PriorityQueue
+from typing import Dict, Iterable, List, Set, Tuple
from synapse.api.errors import StoreError
+from synapse.events import EventBase
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
@@ -33,57 +32,51 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
- def get_auth_chain(self, event_ids, include_given=False):
+ async def get_auth_chain(
+ self, event_ids: Collection[str], include_given: bool = False
+ ) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
Returns:
list of events
"""
- return self.get_auth_chain_ids(
+ event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
- ).addCallback(self.get_events_as_list)
-
- def get_auth_chain_ids(
- self,
- event_ids: List[str],
- include_given: bool = False,
- ignore_events: Optional[Set[str]] = None,
- ):
+ )
+ return await self.get_events_as_list(event_ids)
+
+ async def get_auth_chain_ids(
+ self, event_ids: Collection[str], include_given: bool = False,
+ ) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
event_ids: state events
include_given: include the given events in result
- ignore_events: Set of events to exclude from the returned auth
- chain. This is useful if the caller will just discard the
- given events anyway, and saves us from figuring out their auth
- chains if not required.
Returns:
- list of event_ids
+ An awaitable which resolves to a list of event_ids
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
include_given,
- ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
- if ignore_events is None:
- ignore_events = set()
-
+ def _get_auth_chain_ids_txn(
+ self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE "
+ base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -95,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, args)
new_front.update(r[0] for r in txn)
- new_front -= ignore_events
new_front -= results
front = new_front
@@ -103,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
- def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -112,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
chain.
Returns:
- Deferred[Set[str]]
+ The set of event IDs that are in some, but not all, of the auth chains.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
@@ -260,13 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
- def get_oldest_events_in_room(self, room_id):
- return self.db.runInteraction(
- "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
- )
-
- def get_oldest_events_with_depth_in_room(self, room_id):
- return self.db.runInteraction(
+ async def get_oldest_events_with_depth_in_room(self, room_id):
+ return await self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@@ -287,17 +274,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return dict(txn)
- @defer.inlineCallbacks
- def get_max_depth_of(self, event_ids):
+ async def get_max_depth_of(self, event_ids: List[str]) -> int:
"""Returns the max depth of a set of event IDs
Args:
- event_ids (list[str])
-
- Returns
- Deferred[int]
+ event_ids: The event IDs to calculate the max depth of.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@@ -310,15 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
- def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self.db.simple_select_onecol_txn(
- txn,
- table="event_backward_extremities",
- keyvalues={"room_id": room_id},
- retcol="event_id",
- )
-
- def get_prev_events_for_room(self, room_id: str):
+ async def get_prev_events_for_room(self, room_id: str) -> List[str]:
"""
Gets a subset of the current forward extremities in the given room.
@@ -326,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
events which refer to hundreds of prev_events.
Args:
- room_id (str): room_id
+ room_id: room_id
Returns:
- Deferred[List[str]]: the event ids of the forward extremites
+ The event ids of the forward extremities.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
@@ -353,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row[0] for row in txn]
- def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
+ async def get_rooms_with_many_extremities(
+ self, min_count: int, limit: int, room_id_filter: Iterable[str]
+ ) -> List[str]:
"""Get the top rooms with at least N extremities.
Args:
- min_count (int): The minimum number of extremities
- limit (int): The maximum number of rooms to return.
- room_id_filter (iterable[str]): room_ids to exclude from the results
+ min_count: The minimum number of extremities
+ limit: The maximum number of rooms to return.
+ room_id_filter: room_ids to exclude from the results
Returns:
- Deferred[list]: At most `limit` room IDs that have at least
- `min_count` extremities, sorted by extremity count.
+ At most `limit` room IDs that have at least `min_count` extremities,
+ sorted by extremity count.
"""
def _get_rooms_with_many_extremities_txn(txn):
@@ -388,28 +365,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@cached(max_entries=5000, iterable=True)
- def get_latest_event_ids_in_room(self, room_id):
- return self.db.simple_select_onecol(
+ async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
- def get_min_depth(self, room_id):
- """ For hte given room, get the minimum depth we have seen for it.
+ async def get_min_depth(self, room_id: str) -> int:
+ """For the given room, get the minimum depth we have seen for it.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
- min_depth = self.db.simple_select_one_onecol_txn(
+ min_depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -419,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None
- def get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ async def get_forward_extremeties_for_room(
+ self, room_id: str, stream_ordering: int
+ ) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -427,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
stream_orderings from that point.
Args:
- room_id (str):
- stream_ordering (int):
+ room_id:
+ stream_ordering:
Returns:
- deferred, which resolves to a list of event_ids
+ A list of event_ids
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
@@ -447,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
- return self._get_forward_extremeties_for_room(room_id, stream_ordering)
+ return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
- def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -475,31 +454,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
- def get_backfill_events(self, room_id, event_list, limit):
+ async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
Args:
- txn
- room_id (str)
- event_list (list)
- limit (int)
+ room_id: The room to fetch events from.
+ event_list: The event IDs to backfill from (these are included in the result).
+ limit: The maximum number of events to return.
"""
- return (
- self.db.runInteraction(
- "get_backfill_events",
- self._get_backfill_events,
- room_id,
- event_list,
- limit,
- )
- .addCallback(self.get_events_as_list)
- .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+ event_ids = await self.db_pool.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events,
+ room_id,
+ event_list,
+ limit,
)
+ events = await self.get_events_as_list(event_ids)
+ return sorted(events, key=lambda e: -e.depth)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -521,7 +497,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
- depth = self.db.simple_select_one_onecol_txn(
+ depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -551,9 +527,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return event_results
- @defer.inlineCallbacks
- def get_missing_events(self, room_id, earliest_events, latest_events, limit):
- ids = yield self.db.runInteraction(
+ async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+ ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@@ -561,8 +536,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = yield self.get_events_as_list(ids)
- return events
+ return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -596,17 +570,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.reverse()
return event_results
- @defer.inlineCallbacks
- def get_successor_events(self, event_ids):
+ async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
"""Fetch all events that have the given events as a prev event
Args:
- event_ids (iterable[str])
-
- Returns:
- Deferred[list[str]]
+ event_ids: The events to use as the previous events.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@@ -629,10 +599,10 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventFederationStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
@@ -659,13 +629,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process(
"delete_old_forward_extrem_cache",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
- def clean_room_for_join(self, room_id):
- return self.db.runInteraction(
+ async def clean_room_for_join(self, room_id):
+ return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
@@ -675,8 +645,7 @@ class EventFederationStore(EventFederationWorkerStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- @defer.inlineCallbacks
- def _background_delete_non_state_event_auth(self, progress, batch_size):
+ async def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn):
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
@@ -709,17 +678,19 @@ class EventFederationStore(EventFederationWorkerStore):
"max_stream_id_exclusive": min_stream_id,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
)
return min_stream_id >= target_min_stream_id
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
- yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+ await self.db_pool.updates._end_background_update(
+ self.EVENT_AUTH_STATE_ONLY
+ )
return batch_size
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 0321274de2..5233ed83e2 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,17 +15,15 @@
# limitations under the License.
import logging
+from typing import Dict, List, Optional, Tuple, Union
-from six import iteritems
-
-from canonicaljson import json
-
-from twisted.internet import defer
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore
-from synapse.storage.database import Database
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.util import json_encoder
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -53,14 +51,14 @@ def _serialize_action(actions, is_highlight):
else:
if actions == DEFAULT_NOTIF_ACTION:
return ""
- return json.dumps(actions)
+ return json_encoder.encode(actions)
def _deserialize_action(actions, is_highlight):
"""Custom deserializer for actions. This allows us to "compress" common actions
"""
if actions:
- return json.loads(actions)
+ return db_to_json(actions)
if is_highlight:
return DEFAULT_HIGHLIGHT_ACTION
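
As a rough illustration of the compression above: the common default action collapses to the empty string, while anything else round-trips through JSON. The constant value below is an assumption made for the sketch (it is not copied from this file), and plain `json` stands in for `json_encoder`/`db_to_json`:

import json

DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]  # assumed value

def serialize_action(actions, is_highlight):
    # Simplified: only handles the non-highlight default; custom actions are JSON-encoded.
    if not is_highlight and actions == DEFAULT_NOTIF_ACTION:
        return ""
    return json.dumps(actions)

def deserialize_action(actions, is_highlight):
    if actions:
        return json.loads(actions)
    return DEFAULT_NOTIF_ACTION

custom = ["notify", {"set_tweak": "sound", "value": "ping"}]
assert deserialize_action(serialize_action(DEFAULT_NOTIF_ACTION, False), False) == DEFAULT_NOTIF_ACTION
assert deserialize_action(serialize_action(custom, False), False) == custom
print(repr(serialize_action(DEFAULT_NOTIF_ACTION, False)), serialize_action(custom, False))
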
@@ -69,7 +67,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
@@ -90,114 +88,141 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
- @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
- def get_unread_event_push_actions_by_room_for_user(
- self, room_id, user_id, last_read_event_id
- ):
- ret = yield self.db.runInteraction(
+ @cached(num_args=3, tree=True, max_entries=5000)
+ async def get_unread_event_push_actions_by_room_for_user(
+ self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+ ) -> Dict[str, int]:
+ """Get the notification count, the highlight count and the unread message count
+ for a given user in a given room after the given read receipt.
+
+ Note that this function assumes the user to be a current member of the room,
+ since it's either called by the sync handler to handle joined room entries, or by
+ the HTTP pusher to calculate the badge of unread joined rooms.
+
+ Args:
+ room_id: The room to retrieve the counts in.
+ user_id: The user to retrieve the counts for.
+ last_read_event_id: The event associated with the latest read receipt for
+ this user in this room. None if no receipt for this user in this room.
+
+        Returns:

+ A dict containing the counts mentioned earlier in this docstring,
+ respectively under the keys "notify_count", "highlight_count" and
+ "unread_count".
+ """
+ return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
- return ret
def _get_unread_counts_by_receipt_txn(
- self, txn, room_id, user_id, last_read_event_id
+ self, txn, room_id, user_id, last_read_event_id,
):
- sql = (
- "SELECT stream_ordering"
- " FROM events"
- " WHERE room_id = ? AND event_id = ?"
- )
- txn.execute(sql, (room_id, last_read_event_id))
- results = txn.fetchall()
- if len(results) == 0:
- return {"notify_count": 0, "highlight_count": 0}
+ stream_ordering = None
+
+ if last_read_event_id is not None:
+ stream_ordering = self.get_stream_id_for_event_txn(
+ txn, last_read_event_id, allow_none=True,
+ )
+
+ if stream_ordering is None:
+ # Either last_read_event_id is None, or it's an event we don't have (e.g.
+ # because it's been purged), in which case retrieve the stream ordering for
+ # the latest membership event from this user in this room (which we assume is
+ # a join).
+ event_id = self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ retcol="event_id",
+ )
- stream_ordering = results[0][0]
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
- # First get number of notifications.
- # We don't need to put a notif=1 clause as all rows always have
- # notif=1
sql = (
- "SELECT count(*)"
+ "SELECT"
+ " COUNT(CASE WHEN notif = 1 THEN 1 END),"
+ " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+ " COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
- " WHERE"
- " user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
+ " WHERE user_id = ?"
+ " AND room_id = ?"
+ " AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
- notify_count = row[0] if row else 0
+
+ (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+ if row:
+ (notif_count, highlight_count, unread_count) = row
txn.execute(
"""
- SELECT notif_count FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
- """,
+ SELECT notif_count, unread_count FROM event_push_summary
+ WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ """,
(room_id, user_id, stream_ordering),
)
- rows = txn.fetchall()
- if rows:
- notify_count += rows[0][0]
+ row = txn.fetchone()
- # Now get the number of highlights
- sql = (
- "SELECT count(*)"
- " FROM event_push_actions ea"
- " WHERE"
- " highlight = 1"
- " AND user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
- )
+ if row:
+ notif_count += row[0]
- txn.execute(sql, (user_id, room_id, stream_ordering))
- row = txn.fetchone()
- highlight_count = row[0] if row else 0
+ if row[1] is not None:
+ # The unread_count column of event_push_summary is NULLable, so we need
+ # to make sure we don't try increasing the unread counts if it's NULL
+ # for this row.
+ unread_count += row[1]
- return {"notify_count": notify_count, "highlight_count": highlight_count}
+ return {
+ "notify_count": notif_count,
+ "unread_count": unread_count,
+ "highlight_count": highlight_count,
+ }
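
The rewritten query above folds what used to be two separate COUNT(*) queries (one for notifications, one for highlights) into a single pass using COUNT(CASE WHEN ...), and adds the new unread column. A self-contained sqlite sketch of the same counting trick over an invented toy table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_push_actions "
    "(user_id TEXT, room_id TEXT, stream_ordering INT, notif INT, highlight INT, unread INT)"
)
conn.executemany(
    "INSERT INTO event_push_actions VALUES (?, ?, ?, ?, ?, ?)",
    [
        ("@u:hs", "!r:hs", 10, 1, 0, 1),
        ("@u:hs", "!r:hs", 11, 1, 1, 1),
        ("@u:hs", "!r:hs", 12, 0, 0, 1),  # counts as unread but does not notify
    ],
)
row = conn.execute(
    "SELECT"
    " COUNT(CASE WHEN notif = 1 THEN 1 END),"
    " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
    " COUNT(CASE WHEN unread = 1 THEN 1 END)"
    " FROM event_push_actions"
    " WHERE user_id = ? AND room_id = ? AND stream_ordering > ?",
    ("@u:hs", "!r:hs", 9),
).fetchone()
print(row)  # (2, 1, 3)
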
- @defer.inlineCallbacks
- def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
+ async def get_push_action_users_in_range(
+ self, min_stream_ordering, max_stream_ordering
+ ):
def f(txn):
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
- " stream_ordering >= ? AND stream_ordering <= ?"
+ " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
+ ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
return ret
- @defer.inlineCallbacks
- def get_unread_push_actions_for_user_in_range_for_http(
- self, user_id, min_stream_ordering, max_stream_ordering, limit=20
- ):
+ async def get_unread_push_actions_for_user_in_range_for_http(
+ self,
+ user_id: str,
+ min_stream_ordering: int,
+ max_stream_ordering: int,
+ limit: int = 20,
+ ) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
Args:
- user_id (str): The user to fetch push actions for.
- min_stream_ordering(int): The exclusive lower bound on the
+ user_id: The user to fetch push actions for.
+ min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch.
- max_stream_ordering(int): The inclusive upper bound on the
+ max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch.
- limit (int): The maximum number of rows to return.
+ limit: The maximum number of rows to return.
Returns:
- A promise which resolves to a list of dicts with the keys "event_id",
- "room_id", "stream_ordering", "actions".
+ A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
@@ -224,13 +249,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.db.runInteraction(
+ after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -252,13 +278,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.db.runInteraction(
+ no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -282,23 +309,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# one of the subqueries may have hit the limit.
return notifs[:limit]
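
The two interactions above (rows after the user's read receipt, and rows for rooms with no receipt) are merged in Python rather than in SQL; because either subquery may have hit `limit` on its own, the combined list is sorted and truncated again. A toy sketch of that merge with invented rows:

limit = 3
after_read_receipt = [("$e1", "!r:hs", 10, "[]"), ("$e3", "!r:hs", 12, "[]")]
no_read_receipt = [("$e2", "!r:hs", 11, "[]"), ("$e4", "!r:hs", 13, "[]")]

notifs = [
    {"event_id": row[0], "room_id": row[1], "stream_ordering": row[2], "actions": row[3]}
    for row in after_read_receipt + no_read_receipt
]
# Order by ascending stream ordering, then re-apply the limit because each
# subquery was limited independently.
notifs.sort(key=lambda r: r["stream_ordering"])
print([n["event_id"] for n in notifs[:limit]])  # ['$e1', '$e2', '$e3']
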
- @defer.inlineCallbacks
- def get_unread_push_actions_for_user_in_range_for_email(
- self, user_id, min_stream_ordering, max_stream_ordering, limit=20
- ):
+ async def get_unread_push_actions_for_user_in_range_for_email(
+ self,
+ user_id: str,
+ min_stream_ordering: int,
+ max_stream_ordering: int,
+ limit: int = 20,
+ ) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
Args:
- user_id (str): The user to fetch push actions for.
- min_stream_ordering(int): The exclusive lower bound on the
+ user_id: The user to fetch push actions for.
+ min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch.
- max_stream_ordering(int): The inclusive upper bound on the
+ max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch.
- limit (int): The maximum number of rows to return.
+ limit: The maximum number of rows to return.
Returns:
- A promise which resolves to a list of dicts with the keys "event_id",
- "room_id", "stream_ordering", "actions", "received_ts".
+ A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
@@ -324,13 +353,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.db.runInteraction(
+ after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -352,13 +382,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
+ " AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.db.runInteraction(
+ no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -383,62 +414,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit`
return notifs[:limit]
- def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
+ async def get_if_maybe_push_in_range_for_user(
+ self, user_id: str, min_stream_ordering: int
+ ) -> bool:
"""A fast check to see if there might be something to push for the
user since the given stream ordering. May return false positives.
Useful to know whether to bother starting a pusher on start up or not.
Args:
- user_id (str)
- min_stream_ordering (int)
+ user_id
+ min_stream_ordering
Returns:
- Deferred[bool]: True if there may be push to process, False if
- there definitely isn't.
+ True if there may be push to process, False if there definitely isn't.
"""
def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """
SELECT 1 FROM event_push_actions
- WHERE user_id = ? AND stream_ordering > ?
+ WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1
"""
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
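
The fast check above is just a LIMIT 1 existence probe coerced to a bool, now restricted to rows with notif = 1. A minimal sqlite sketch of the same probe, with an invented table and values:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE event_push_actions (user_id TEXT, stream_ordering INT, notif INT)")
conn.execute("INSERT INTO event_push_actions VALUES ('@u:hs', 7, 1)")

def maybe_push(user_id: str, min_stream_ordering: int) -> bool:
    row = conn.execute(
        "SELECT 1 FROM event_push_actions"
        " WHERE user_id = ? AND stream_ordering > ? AND notif = 1 LIMIT 1",
        (user_id, min_stream_ordering),
    ).fetchone()
    return bool(row)

print(maybe_push("@u:hs", 5), maybe_push("@u:hs", 7))  # True False
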
- def add_push_actions_to_staging(self, event_id, user_id_actions):
+ async def add_push_actions_to_staging(
+ self,
+ event_id: str,
+ user_id_actions: Dict[str, List[Union[dict, str]]],
+ count_as_unread: bool,
+ ) -> None:
"""Add the push actions for the event to the push action staging area.
Args:
- event_id (str)
- user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
- user_id to list of push actions, where an action can either be
- a string or dict.
-
- Returns:
- Deferred
+ event_id
+ user_id_actions: A mapping of user_id to list of push actions, where
+ an action can either be a string or dict.
+ count_as_unread: Whether this event should increment unread counts.
"""
-
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
- # can be used to inert into the `event_push_actions_staging` table.
+ # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
+ notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
- 1, # notif column
+ notif, # notif column
is_highlight, # highlight column
+ int(count_as_unread), # unread column
)
def _add_push_actions_to_staging_txn(txn):
@@ -447,33 +482,29 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight)
- VALUES (?, ?, ?, ?, ?)
+ (event_id, user_id, actions, notif, highlight, unread)
+ VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
sql,
(
_gen_entry(user_id, actions)
- for user_id, actions in iteritems(user_id_actions)
+ for user_id, actions in user_id_actions.items()
),
)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
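
To make the per-user row generation above concrete: each entry is a 6-tuple lining up with the staging table's columns, with the new notif and unread values derived from the actions and the count_as_unread flag. A toy version with a simplified highlight check and invented IDs:

import json

def action_has_highlight(actions):
    # Simplified stand-in for _action_has_highlight.
    for a in actions:
        if isinstance(a, dict) and a.get("set_tweak") == "highlight" and a.get("value", True):
            return True
    return False

def gen_entry(event_id, user_id, actions, count_as_unread):
    is_highlight = 1 if action_has_highlight(actions) else 0
    notif = 1 if "notify" in actions else 0
    return (event_id, user_id, json.dumps(actions), notif, is_highlight, int(count_as_unread))

print(gen_entry("$ev", "@u:hs", ["notify", {"set_tweak": "highlight"}], True))
# ('$ev', '@u:hs', '["notify", {"set_tweak": "highlight"}]', 1, 1, 1)
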
- @defer.inlineCallbacks
- def remove_push_actions_from_staging(self, event_id):
+ async def remove_push_actions_from_staging(self, event_id: str) -> None:
"""Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB
-
- Args:
- event_id (str)
"""
try:
- res = yield self.db.simple_delete(
+ res = await self.db_pool.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@@ -490,7 +521,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self):
return run_as_background_process(
"event_push_action_stream_orderings",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
@@ -511,7 +542,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
- def find_first_stream_ordering_after_ts(self, ts):
+ async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
"""Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was
@@ -520,13 +551,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
relatively slow.
Args:
- ts (int): timestamp in millis
+ ts: timestamp in millis
Returns:
- Deferred[int]: stream ordering of the first event received on/after
- the timestamp
+ stream ordering of the first event received on/after the timestamp
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
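
The txn helper this delegates to is not shown in this hunk; purely as an illustration of the lookup the docstring describes (first stream ordering received on or after a timestamp), here is a bisect over an invented in-memory list of (stream_ordering, received_ts) pairs:

import bisect

events = [(1, 1000), (2, 1500), (3, 3000), (4, 3000), (5, 4500)]  # invented data

def find_first_stream_ordering_after_ts(ts: int) -> int:
    received = [received_ts for _, received_ts in events]
    idx = bisect.bisect_left(received, ts)
    if idx == len(events):
        # Nothing received on/after ts: fall back to just past the newest event.
        return events[-1][0] + 1
    return events[idx][0]

print(find_first_stream_ordering_after_ts(2000))  # 3
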
@@ -608,38 +638,39 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
- @defer.inlineCallbacks
- def get_time_of_last_push_action_before(self, stream_ordering):
+ async def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ?"
+ " WHERE ep.stream_ordering > ? AND notif = 1"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
- result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+ result = await self.db_pool.runInteraction(
+ "get_time_of_last_push_action_before", f
+ )
return result[0] if result else None
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventPushActionsStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight",
table="event_push_actions",
columns=["user_id", "stream_ordering"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
@@ -652,8 +683,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000
)
- @defer.inlineCallbacks
- def get_push_actions_for_user(
+ async def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False
):
def f(txn):
@@ -678,24 +708,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" FROM event_push_actions epa, events e"
" WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
+ " AND epa.notif = 1"
" ORDER BY epa.stream_ordering DESC"
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
+ push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- @defer.inlineCallbacks
- def get_latest_push_action_stream_ordering(self):
+ async def get_latest_push_action_stream_ordering(self):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_latest_push_action_stream_ordering", f
)
return result[0] or 0
@@ -749,8 +779,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs)
- @defer.inlineCallbacks
- def _rotate_notifs(self):
+ async def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
self._doing_notif_rotation = True
@@ -759,12 +788,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
- caught_up = yield self.db.runInteraction(
+ caught_up = await self.db_pool.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
break
- yield self.hs.get_clock().sleep(self._rotate_delay)
+ await self.hs.get_clock().sleep(self._rotate_delay)
finally:
self._doing_notif_rotation = False
@@ -773,7 +802,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
- old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -809,7 +838,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
- old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -819,49 +848,100 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
- coalesce(old.notif_count, 0) + upd.notif_count,
+ coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
- SELECT user_id, room_id, count(*) as notif_count,
+ SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
+ AND %s = 1
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
- txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
- rows = txn.fetchall()
+ # First get the count of unread messages.
+ txn.execute(
+ sql % ("unread_count", "unread"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ # We need to merge results from the two requests (the one that retrieves the
+ # unread count and the one that retrieves the notifications count) into a single
+            # object because we might not have the same number of rows in each of them. To do
+ # this, we use a dict indexed on the user ID and room ID to make it easier to
+ # populate.
+ summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
+ for row in txn:
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=row[2],
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=0,
+ )
+
+ # Then get the count of notifications.
+ txn.execute(
+ sql % ("notif_count", "notif"),
+ (old_rotate_stream_ordering, rotate_to_stream_ordering),
+ )
+
+ for row in txn:
+ if (row[0], row[1]) in summaries:
+ summaries[(row[0], row[1])].notif_count = row[2]
+ else:
+ # Because the rules on notifying are different than the rules on marking
+ # a message unread, we might end up with messages that notify but aren't
+ # marked unread, so we might not have a summary for this (user, room)
+ # tuple to complete.
+ summaries[(row[0], row[1])] = _EventPushSummary(
+ unread_count=0,
+ stream_ordering=row[3],
+ old_user_id=row[4],
+ notif_count=row[2],
+ )
- logger.info("Rotating notifications, handling %d rows", len(rows))
+ logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[
{
- "user_id": row[0],
- "room_id": row[1],
- "notif_count": row[2],
- "stream_ordering": row[3],
+ "user_id": user_id,
+ "room_id": room_id,
+ "notif_count": summary.notif_count,
+ "unread_count": summary.unread_count,
+ "stream_ordering": summary.stream_ordering,
}
- for row in rows
- if row[4] is None
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is None
],
)
txn.executemany(
"""
- UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+ UPDATE event_push_summary
+ SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
- ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+ (
+ (
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ user_id,
+ room_id,
+ )
+ for ((user_id, room_id), summary) in summaries.items()
+ if summary.old_user_id is not None
+ ),
)
txn.execute(
@@ -887,3 +967,15 @@ def _action_has_highlight(actions):
pass
return False
+
+
+@attr.s
+class _EventPushSummary:
+ """Summary of pending event push actions for a given user in a given room.
+ Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+ """
+
+ unread_count = attr.ib(type=int)
+ stream_ordering = attr.ib(type=int)
+ old_user_id = attr.ib(type=str)
+ notif_count = attr.ib(type=int)
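
To make the two-pass fold above concrete, a standalone toy of merging the "unread" query rows with the "notif" query rows into one summary per (user_id, room_id). The row tuples are invented, and PushSummary mirrors _EventPushSummary only for this sketch:

import attr

@attr.s
class PushSummary:
    unread_count = attr.ib(type=int)
    stream_ordering = attr.ib(type=int)
    old_user_id = attr.ib(type=str)
    notif_count = attr.ib(type=int)

# (user_id, room_id, count, stream_ordering, old.user_id), as in the SQL above.
unread_rows = [("@u:hs", "!r:hs", 3, 42, None)]
notif_rows = [("@u:hs", "!r:hs", 1, 42, None), ("@v:hs", "!r:hs", 2, 40, "@v:hs")]

summaries = {}
for user_id, room_id, cnt, stream, old_user in unread_rows:
    summaries[(user_id, room_id)] = PushSummary(cnt, stream, old_user, notif_count=0)
for user_id, room_id, cnt, stream, old_user in notif_rows:
    key = (user_id, room_id)
    if key in summaries:
        summaries[key].notif_count = cnt
    else:
        # Notifying but not marked unread: create a summary with unread_count=0.
        summaries[key] = PushSummary(0, stream, old_user, notif_count=cnt)

print(summaries)
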
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/databases/main/events.py
index a6572571b4..b3d27a2ee7 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -14,45 +14,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import OrderedDict, namedtuple
-from functools import wraps
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
-
-from six import integer_types, iteritems, text_type
-from six.moves import range
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
import attr
-from canonicaljson import json
from prometheus_client import Counter
-from twisted.internet import defer
-
import synapse.metrics
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- Membership,
- RelationTypes,
-)
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
-from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.data_stores.main.search import SearchEntry
-from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage._base import db_to_json, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
- from synapse.storage.data_stores.main import DataStore
from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -78,27 +65,6 @@ def encode_json(json_object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-def _retry_on_integrity_error(func):
- """Wraps a database function so that it gets retried on IntegrityError,
- with `delete_existing=True` passed in.
-
- Args:
- func: function that returns a Deferred and accepts a `delete_existing` arg
- """
-
- @wraps(func)
- @defer.inlineCallbacks
- def f(self, *args, **kwargs):
- try:
- res = yield func(self, *args, delete_existing=False, **kwargs)
- except self.database_engine.module.IntegrityError:
- logger.exception("IntegrityError, retrying.")
- res = yield func(self, *args, delete_existing=True, **kwargs)
- return res
-
- return f
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -123,9 +89,11 @@ class PersistEventsStore:
Note: This is not part of the `DataStore` mixin.
"""
- def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"):
+ def __init__(
+ self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
+ ):
self.hs = hs
- self.db = db
+ self.db_pool = db
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
@@ -143,17 +111,14 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
- @_retry_on_integrity_error
- @defer.inlineCallbacks
- def _persist_events_and_state_updates(
+ async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
- delete_existing: bool = False,
- ):
+ ) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -166,10 +131,9 @@ class PersistEventsStore:
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
backfilled
- delete_existing
Returns:
- Deferred: resolves when the events have been persisted
+ Resolves when the events have been persisted
"""
# We want to calculate the stream orderings as late as possible, as
@@ -189,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -201,12 +165,11 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
backfilled=backfilled,
- delete_existing=delete_existing,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
@@ -232,24 +195,23 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, new_state in iteritems(current_state_for_room):
+ for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
- for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+ for room_id, latest_event_ids in new_forward_extremeties.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
- @defer.inlineCallbacks
- def _get_events_which_are_prevs(self, event_ids):
+ async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
Args:
- event_ids (Iterable[str]): event ids to filter
+ event_ids: event ids to filter
Returns:
- Deferred[List[str]]: filtered event ids
+ Filtered event ids
"""
results = []
@@ -271,17 +233,16 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
+ results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
return results
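
The loop above processes event IDs in chunks of 100, one interaction per chunk, and keeps only rows whose metadata does not mark them soft-failed. A sketch of that chunk-and-filter pattern; batch_iter below is a simplified equivalent of synapse.util.iterutils.batch_iter, and the metadata rows are invented:

import itertools
import json
from typing import Iterable, Iterator, List, Tuple

def batch_iter(iterable: Iterable[str], size: int) -> Iterator[Tuple[str, ...]]:
    it = iter(iterable)
    return iter(lambda: tuple(itertools.islice(it, size)), ())

metadata_rows = {"$a": '{"soft_failed": true}', "$b": "{}", "$c": "{}"}

def events_which_are_prevs(event_ids: Iterable[str]) -> List[str]:
    results = []
    for chunk in batch_iter(event_ids, 100):
        # In the store this would be one runInteraction per chunk of up to 100 IDs.
        results.extend(
            ev for ev in chunk
            if ev in metadata_rows and not json.loads(metadata_rows[ev]).get("soft_failed")
        )
    return results

print(events_which_are_prevs(["$a", "$b", "$c", "$d"]))  # ['$b', '$c']
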
- @defer.inlineCallbacks
- def _get_prevs_before_rejected(self, event_ids):
+ async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
@@ -293,11 +254,11 @@ class PersistEventsStore:
are separated by soft failed events.
Args:
- event_ids (Iterable[str]): Events to find prev events for. Note
- that these must have already been persisted.
+ event_ids: Events to find prev events for. Note that these must have
+ already been persisted.
Returns:
- Deferred[set[str]]
+ The previous events.
"""
# The set of event_ids to return. This includes all soft-failed events
@@ -332,13 +293,13 @@ class PersistEventsStore:
if prev_event_id in existing_prevs:
continue
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
to_recursively_check.append(prev_event_id)
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
@@ -350,7 +311,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
- delete_existing: bool = False,
state_delta_for_room: Dict[str, DeltaState] = {},
new_forward_extremeties: Dict[str, List[str]] = {},
):
@@ -402,13 +362,6 @@ class PersistEventsStore:
# From this point onwards the events are only events that we haven't
# seen before.
- if delete_existing:
- # For paranoia reasons, we go and delete all the existing entries
- # for these events so we can reinsert them.
- # This gets around any problems with some tables already having
- # entries.
- self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
-
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
# Insert into event_to_state_groups.
@@ -420,7 +373,7 @@ class PersistEventsStore:
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -461,7 +414,7 @@ class PersistEventsStore:
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
):
- for room_id, delta_state in iteritems(state_delta_by_room):
+ for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
@@ -483,7 +436,7 @@ class PersistEventsStore:
"""
txn.execute(sql, (stream_id, room_id))
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="current_state_events", keyvalues={"room_id": room_id},
)
else:
@@ -545,7 +498,7 @@ class PersistEventsStore:
""",
[
(room_id, key[0], key[1], ev_id, ev_id)
- for key, ev_id in iteritems(to_insert)
+ for key, ev_id in to_insert.items()
],
)
@@ -626,12 +579,12 @@ class PersistEventsStore:
txn.execute(sql, (room_id, EventTypes.Create, ""))
row = txn.fetchone()
if row:
- event_json = json.loads(row[0])
+ event_json = db_to_json(row[0])
content = event_json.get("content", {})
creator = content.get("creator")
room_version_id = content.get("room_version", RoomVersions.V1.identifier)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
@@ -642,20 +595,20 @@ class PersistEventsStore:
def _update_forward_extremities_txn(
self, txn, new_forward_extremities, max_stream_order
):
- for room_id, new_extrem in iteritems(new_forward_extremities):
- self.db.simple_delete_txn(
+ for room_id, new_extrem in new_forward_extremities.items():
+ self.db_pool.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
txn.call_after(
self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
{"event_id": ev_id, "room_id": room_id}
- for room_id, new_extrem in iteritems(new_forward_extremities)
+ for room_id, new_extrem in new_forward_extremities.items()
for ev_id in new_extrem
],
)
@@ -663,7 +616,7 @@ class PersistEventsStore:
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
@@ -672,7 +625,7 @@ class PersistEventsStore:
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in iteritems(new_forward_extremities)
+ for room_id, new_extrem in new_forward_extremities.items()
for event_id in new_extrem
],
)
@@ -727,7 +680,7 @@ class PersistEventsStore:
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in iteritems(depth_updates):
+ for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts):
@@ -787,7 +740,7 @@ class PersistEventsStore:
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
@@ -806,40 +759,6 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
- @classmethod
- def _delete_existing_rows_txn(cls, txn, events_and_contexts):
- if not events_and_contexts:
- # nothing to do here
- return
-
- logger.info("Deleting existing")
-
- for table in (
- "events",
- "event_auth",
- "event_json",
- "event_edges",
- "event_forward_extremities",
- "event_reference_hashes",
- "event_search",
- "event_to_state_groups",
- "local_invites",
- "state_events",
- "rejections",
- "redactions",
- "room_memberships",
- ):
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- [(ev.event_id,) for ev, _ in events_and_contexts],
- )
-
- for table in ("event_push_actions",):
- txn.executemany(
- "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
- [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
- )
-
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables
@@ -859,7 +778,7 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
values=[
@@ -876,7 +795,7 @@ class PersistEventsStore:
],
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="events",
values=[
@@ -893,8 +812,7 @@ class PersistEventsStore:
"received_ts": self._clock.time_msec(),
"sender": event.sender,
"contains_url": (
- "url" in event.content
- and isinstance(event.content["url"], text_type)
+ "url" in event.content and isinstance(event.content["url"], str)
),
}
for event, _ in events_and_contexts
@@ -906,7 +824,7 @@ class PersistEventsStore:
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="redactions",
keyvalues={"redacts": event.event_id},
@@ -1048,7 +966,9 @@ class PersistEventsStore:
state_values.append(vals)
- self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
+ self.db_pool.simple_insert_many_txn(
+ txn, table="state_events", values=state_values
+ )
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1079,7 +999,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
@@ -1099,7 +1019,7 @@ class PersistEventsStore:
# invalidate the cache for the redacted event
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="redactions",
values={
@@ -1122,7 +1042,7 @@ class PersistEventsStore:
room_id (str): The ID of the room the event was sent to.
topological_ordering (int): The position of the event in the room's topology.
"""
- return self.db.simple_insert_many_txn(
+ return self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@@ -1144,7 +1064,7 @@ class PersistEventsStore:
event_id (str): The event ID the expiry timestamp is associated with.
expiry_ts (int): The timestamp at which to expire (delete) the event.
"""
- return self.db.simple_insert_txn(
+ return self.db_pool.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
@@ -1168,12 +1088,14 @@ class PersistEventsStore:
}
)
- self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+ self.db_pool.simple_insert_many_txn(
+ txn, table="event_reference_hashes", values=vals
+ )
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
values=[
@@ -1201,65 +1123,27 @@ class PersistEventsStore:
(event.state_key,),
)
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened. If the event is an
- # outlier it is only current if its an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self.db.simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- # We also update the `local_current_membership` table with
- # latest invite info. This will usually get updated by the
- # `current_state_events` handling, unless its an outlier.
- if event.internal_metadata.is_outlier():
- # This should only happen for out of band memberships, so
- # we add a paranoia check.
- assert event.internal_metadata.is_out_of_band_membership()
-
- self.db.simple_upsert_txn(
- txn,
- table="local_current_membership",
- keyvalues={
- "room_id": event.room_id,
- "user_id": event.state_key,
- },
- values={
- "event_id": event.event_id,
- "membership": event.membership,
- },
- )
+ # We update the local_current_membership table only if the event is
+ # "current", i.e., its something that has just happened.
+ #
+ # This will usually get updated by the `current_state_events` handling,
+            # unless it's an outlier, and an outlier is only "current" if it's an "out of
+ # band membership", like a remote invite or a rejection of a remote invite.
+ if (
+ self.is_mine_id(event.state_key)
+ and not backfilled
+ and event.internal_metadata.is_outlier()
+ and event.internal_metadata.is_out_of_band_membership()
+ ):
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": event.room_id, "user_id": event.state_key},
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
@@ -1289,7 +1173,7 @@ class PersistEventsStore:
aggregation_key = relation.get("key")
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="event_relations",
values={
@@ -1317,7 +1201,7 @@ class PersistEventsStore:
redacted_event_id (str): The event that was redacted.
"""
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
@@ -1345,15 +1229,15 @@ class PersistEventsStore:
):
if (
"min_lifetime" in event.content
- and not isinstance(event.content.get("min_lifetime"), integer_types)
+ and not isinstance(event.content.get("min_lifetime"), int)
) or (
"max_lifetime" in event.content
- and not isinstance(event.content.get("max_lifetime"), integer_types)
+ and not isinstance(event.content.get("max_lifetime"), int)
):
# Ignore the event if one of the value isn't an integer.
return
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@@ -1412,9 +1296,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight
+ topological_ordering, notif, highlight, unread
)
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging
WHERE event_id = ?
"""
@@ -1434,7 +1318,7 @@ class PersistEventsStore:
)
for event, _ in events_and_contexts:
- user_ids = self.db.simple_select_onecol_txn(
+ user_ids = self.db_pool.simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
@@ -1466,7 +1350,7 @@ class PersistEventsStore:
)
def _store_rejections_txn(self, txn, event_id, reason):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="rejections",
values={
@@ -1492,16 +1376,16 @@ class PersistEventsStore:
state_groups[event.event_id] = context.state_group
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
+ for event_id, state_group_id in state_groups.items()
],
)
- for event_id, state_group_id in iteritems(state_groups):
+ for event_id, state_group_id in state_groups.items():
txn.call_after(
self.store._get_state_group_for_event.prefill,
(event_id,),
@@ -1514,7 +1398,7 @@ class PersistEventsStore:
if min_depth is not None and depth >= min_depth:
return
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -1526,7 +1410,7 @@ class PersistEventsStore:
For the given event, update the event edges table and forward and
backward extremities tables.
"""
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_edges",
values=[
@@ -1590,31 +1474,3 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
-
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
-
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- # We also clear this entry from `local_current_membership`.
- # Ideally we'd point to a leave event, but we don't have one, so
- # nevermind.
- self.db.simple_delete_txn(
- txn,
- table="local_current_membership",
- keyvalues={"room_id": room_id, "user_id": user_id},
- )
-
- with self._stream_id_gen.get_next() as stream_ordering:
- await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-
- return stream_ordering
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index f54c8b1ee0..e53c6373a8 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,15 +15,9 @@
import logging
-from six import text_type
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool
logger = logging.getLogger(__name__)
@@ -34,18 +28,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_contains_url_index",
index_name="event_contains_url_index",
table="events",
@@ -56,7 +50,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE
# clause
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_search_event_id_idx",
index_name="event_search_event_id_idx",
table="event_search",
@@ -65,16 +59,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
psql_only=True,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"redactions_received_ts", self._redactions_received_ts
)
# This index gets deleted in `event_fix_redactions_bytes` update
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_fix_redactions_bytes_create_index",
index_name="redactions_censored_redacts",
table="redactions",
@@ -82,15 +76,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="have_censored",
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"event_fix_redactions_bytes", self._event_fix_redactions_bytes
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"event_store_labels", self._event_store_labels
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"redactions_have_censored_ts_idx",
index_name="redactions_have_censored_ts",
table="redactions",
@@ -98,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored",
)
- @defer.inlineCallbacks
- def _background_reindex_fields_sender(self, progress, batch_size):
+ async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -127,13 +120,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in rows:
try:
event_id = row[1]
- event_json = json.loads(row[2])
+ event_json = db_to_json(row[2])
sender = event_json["sender"]
content = event_json["content"]
contains_url = "url" in content
if contains_url:
- contains_url &= isinstance(content["url"], text_type)
+ contains_url &= isinstance(content["url"], str)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -153,25 +146,24 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rows_inserted": rows_inserted + len(rows),
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
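
After the conversion, the handlers in this file all follow the same loop: run one batch inside runInteraction, persist progress, and end the background update when a batch comes back empty. A condensed, hypothetical sketch of that loop with invented stand-ins for the pool and updater (none of these classes are Synapse's real machinery):

import asyncio

class FakeUpdates:
    def __init__(self):
        self.progress = {}
        self.done = False
    def _background_update_progress_txn(self, txn, name, progress):
        self.progress[name] = progress
    async def _end_background_update(self, name):
        self.done = True

class FakeDBPool:
    def __init__(self):
        self.updates = FakeUpdates()
    async def runInteraction(self, desc, func, *args):
        return func(None, *args)

db_pool = FakeDBPool()
rows = list(range(10))  # pretend table contents

async def example_bg_update(progress, batch_size):
    def do_batch(txn):
        start = progress.get("offset", 0)
        batch = rows[start : start + batch_size]
        if batch:
            db_pool.updates._background_update_progress_txn(
                txn, "example", {"offset": start + len(batch)}
            )
        return len(batch)

    count = await db_pool.runInteraction("example", do_batch)
    if not count:
        await db_pool.updates._end_background_update("example")
    return count

async def main():
    progress = {}
    while not db_pool.updates.done:
        await example_bg_update(progress, batch_size=4)
        progress = db_pool.updates.progress.get("example", progress)
    print("done after", progress)

asyncio.run(main())
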
- @defer.inlineCallbacks
- def _background_reindex_origin_server_ts(self, progress, batch_size):
+ async def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -199,7 +191,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self.db.simple_select_many_txn(
+ ev_rows = self.db_pool.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@@ -210,7 +202,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in ev_rows:
event_id = row["event_id"]
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -232,25 +224,24 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rows_inserted": rows_inserted + len(rows_to_update),
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
- @defer.inlineCallbacks
- def _cleanup_extremities_bg_update(self, progress, batch_size):
+ async def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been
deleted previously.
@@ -319,7 +310,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
soft_failed = False
if metadata:
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
@@ -360,7 +351,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
graph[event_id] = {prev_event_id}
- soft_failed = json.loads(metadata).get("soft_failed")
+ soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
else:
@@ -378,7 +369,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
to_delete.intersection_update(original_set)
- deleted = self.db.simple_delete_many_txn(
+ deleted = self.db_pool.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@@ -394,7 +385,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
@@ -408,7 +399,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@@ -418,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
- num_handled = yield self.db.runInteraction(
+ num_handled = await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
return num_handled
- @defer.inlineCallbacks
- def _redactions_received_ts(self, progress, batch_size):
+ async def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
last_event_id = progress.get("last_event_id", "")
@@ -478,23 +468,22 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "redactions_received_ts", {"last_event_id": upper_event_id}
)
return len(rows)
- count = yield self.db.runInteraction(
+ count = await self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
- yield self.db.updates._end_background_update("redactions_received_ts")
+ await self.db_pool.updates._end_background_update("redactions_received_ts")
return count
- @defer.inlineCallbacks
- def _event_fix_redactions_bytes(self, progress, batch_size):
+ async def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
@@ -515,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
- yield self.db.updates._end_background_update("event_fix_redactions_bytes")
+ await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
- @defer.inlineCallbacks
- def _event_store_labels(self, progress, batch_size):
+ async def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "")
@@ -545,9 +533,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
last_row_event_id = ""
for (event_id, event_json_raw) in results:
try:
- event_json = json.loads(event_json_raw)
+ event_json = db_to_json(event_json_raw)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@@ -573,17 +561,17 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
nbrows += 1
last_row_event_id = event_id
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "event_store_labels", {"last_event_id": last_row_event_id}
)
return nbrows
- num_rows = yield self.db.runInteraction(
+ num_rows = await self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
- yield self.db.updates._end_background_update("event_store_labels")
+ await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
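
Every background update converted in this file now follows the same shape: a transaction function does one batch of work and records its position via `_background_update_progress_txn`, the coroutine drives it with `db_pool.runInteraction`, and `_end_background_update` is awaited once a batch comes back empty. A minimal sketch of that shape, with a hypothetical update name and query that are not part of this patch:

    async def _example_background_update(self, progress, batch_size):
        # Hypothetical handler; the update name and SQL are placeholders.
        last_event_id = progress.get("last_event_id", "")

        def _example_txn(txn):
            # Process at most `batch_size` rows past the last recorded position.
            txn.execute(
                "SELECT event_id FROM event_json WHERE event_id > ? ORDER BY event_id LIMIT ?",
                (last_event_id, batch_size),
            )
            rows = txn.fetchall()
            if rows:
                self.db_pool.updates._background_update_progress_txn(
                    txn, "example_update", {"last_event_id": rows[-1][0]}
                )
            return len(rows)

        num_rows = await self.db_pool.runInteraction("example_update", _example_txn)
        if not num_rows:
            # No rows left to process: mark the update as finished.
            await self.db_pool.updates._end_background_update("example_update")
        return num_rows
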
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 213d69100a..a7a73cc3d8 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,10 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
-from canonicaljson import json
from constantly import NamedConstant, Names
+from typing_extensions import Literal
from twisted.internet import defer
@@ -33,16 +33,18 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import BackfillStream
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -73,17 +75,14 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
+ db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
@@ -113,70 +112,59 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_ongoing = 0
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "events":
- self._stream_id_gen.advance(token)
- elif stream_name == "backfill":
- self._backfill_id_gen.advance(-token)
+ if stream_name == EventsStream.NAME:
+ self._stream_id_gen.advance(instance_name, token)
+ elif stream_name == BackfillStream.NAME:
+ self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows)
- def get_received_ts(self, event_id):
+ async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
- event_id (str)
+ event_id: The event ID to query.
Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
+ Timestamp in milliseconds, or None for events that were persisted
+ before received_ts was implemented.
"""
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
desc="get_received_ts",
)
- def get_received_ts_by_stream_pos(self, stream_ordering):
- """Given a stream ordering get an approximate timestamp of when it
- happened.
-
- This is done by simply taking the received ts of the first event that
- has a stream ordering greater than or equal to the given stream pos.
- If none exists returns the current time, on the assumption that it must
- have happened recently.
-
- Args:
- stream_ordering (int)
-
- Returns:
- Deferred[int]
- """
-
- def _get_approximate_received_ts_txn(txn):
- sql = """
- SELECT received_ts FROM events
- WHERE stream_ordering >= ?
- LIMIT 1
- """
-
- txn.execute(sql, (stream_ordering,))
- row = txn.fetchone()
- if row and row[0]:
- ts = row[0]
- else:
- ts = self.clock.time_msec()
-
- return ts
+ # Inform mypy that if allow_none is False (the default) then get_event
+ # always returns an EventBase.
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[False] = False,
+ check_room_id: Optional[str] = None,
+ ) -> EventBase:
+ ...
- return self.db.runInteraction(
- "get_approximate_received_ts", _get_approximate_received_ts_txn
- )
+ @overload
+ async def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[True] = False,
+ check_room_id: Optional[str] = None,
+ ) -> Optional[EventBase]:
+ ...
- @defer.inlineCallbacks
- def get_event(
+ async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -184,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
- ):
+ ) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@@ -209,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred[EventBase|None]
+ The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -232,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
- @defer.inlineCallbacks
- def get_events(
+ async def get_events(
self,
- event_ids: List[str],
+ event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@@ -258,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred : Dict from event_id to event.
+ A mapping from event_id to event.
"""
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -269,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
- @defer.inlineCallbacks
- def get_events_as_list(
+ async def get_events_as_list(
self,
- event_ids: List[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
- ):
+ ) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@@ -297,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
+ List of events fetched from the database. The events are in the same
+ order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@@ -308,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = yield self._get_events_from_cache_or_db(
+ event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -343,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -409,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
+ prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@@ -421,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
- @defer.inlineCallbacks
- def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -437,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@@ -455,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = yield self._get_events_from_db(
+ missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -539,7 +524,7 @@ class EventsWorkerStore(SQLBaseStore):
event_id for events, _ in event_list for event_id in events
}
- row_dict = self.db.new_transaction(
+ row_dict = self.db_pool.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
@@ -563,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- @defer.inlineCallbacks
- def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@@ -578,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
- Deferred[Dict[str, _EventCacheEntry]]:
+ Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@@ -586,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
- row_map = yield self._enqueue_events(events_to_fetch)
+ row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@@ -612,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
- d = json.loads(row["json"])
- internal_metadata = json.loads(row["internal_metadata"])
+ # If the event or metadata cannot be parsed, log the error and act
+ # as if the event is unknown.
+ try:
+ d = db_to_json(row["json"])
+ except ValueError:
+ logger.error("Unable to parse json from event: %s", event_id)
+ continue
+ try:
+ internal_metadata = db_to_json(row["internal_metadata"])
+ except ValueError:
+ logger.error(
+ "Unable to parse internal_metadata from event: %s", event_id
+ )
+ continue
format_version = row["format_version"]
if format_version is None:
@@ -624,24 +620,43 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row["room_version_id"]
if not room_version_id:
- # this should only happen for out-of-band membership events
- if not internal_metadata.get("out_of_band_membership"):
- logger.warning(
- "Room %s for event %s is unknown", d["room_id"], event_id
+ # this should only happen for out-of-band membership events which
+ # arrived before #6983 landed. For all other events, we should have
+ # an entry in the 'rooms' table.
+ #
+ # However, the 'out_of_band_membership' flag is unreliable for older
+ # invites, so just accept it for all membership events.
+ #
+ if d["type"] != EventTypes.Member:
+ raise Exception(
+ "Room %s for event %s is unknown" % (d["room_id"], event_id)
)
- continue
- # take a wild stab at the room version based on the event format
+ # so, assuming this is an out-of-band-invite that arrived before #6983
+ # landed, we know that the room version must be v5 or earlier (because
+ # v6 hadn't been invented at that point, so invites from such rooms
+ # would have been rejected.)
+ #
+ # The main reason we need to know the room version here (other than
+ # choosing the right python Event class) is in case the event later has
+ # to be redacted - and all the room versions up to v5 used the same
+ # redaction algorithm.
+ #
+ # So, the following approximations should be adequate.
+
if format_version == EventFormatVersions.V1:
+ # if it's event format v1 then it must be room v1 or v2
room_version = RoomVersions.V1
elif format_version == EventFormatVersions.V2:
+ # if it's event format v2 then it must be room v3
room_version = RoomVersions.V3
else:
+ # if it's event format v3 then it must be room v4 or v5
room_version = RoomVersions.V5
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
- logger.error(
+ logger.warning(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
@@ -688,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- @defer.inlineCallbacks
- def _enqueue_events(self, events):
+ async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -698,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
- Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@@ -716,12 +730,12 @@ class EventsWorkerStore(SQLBaseStore):
if should_start:
run_as_background_process(
- "fetch_events", self.db.runWithConnection, self._do_fetch
+ "fetch_events", self.db_pool.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
- row_map = yield events_d
+ row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@@ -809,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore):
return event_dict
- def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+ def _maybe_redact_event_row(
+ self,
+ original_ev: EventBase,
+ redactions: Iterable[str],
+ event_map: Dict[str, EventBase],
+ ) -> Optional[EventBase]:
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
Args:
- original_ev (EventBase):
- redactions (iterable[str]): list of event ids of potential redaction events
- event_map (dict[str, EventBase]): other events which have been fetched, in
- which we can look up the redaaction events. Map from event id to event.
+ original_ev: The original event.
+ redactions: list of event ids of potential redaction events
+ event_map: other events which have been fetched, in which we can
+ look up the redaction events. Map from event id to event.
Returns:
- Deferred[EventBase|None]: if the event should be redacted, a pruned
- event object. Otherwise, None.
+ If the event should be redacted, a pruned event object. Otherwise, None.
"""
if original_ev.type == "m.room.create":
# we choose to ignore redactions of m.room.create events.
@@ -880,12 +898,11 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@@ -896,15 +913,14 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
+ async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
- Deferred[set[str]]: The events we have already seen.
+ set[str]: The events we have already seen.
"""
results = set()
@@ -920,41 +936,11 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
- def _get_total_state_event_counts_txn(self, txn, room_id):
- """
- See get_total_state_event_counts.
- """
- # We join against the events table as that has an index on room_id
- sql = """
- SELECT COUNT(*) FROM state_events
- INNER JOIN events USING (room_id, event_id)
- WHERE room_id=?
- """
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_total_state_event_counts(self, room_id):
- """
- Gets the total number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.db.runInteraction(
- "get_total_state_event_counts",
- self._get_total_state_event_counts_txn,
- room_id,
- )
-
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -964,24 +950,23 @@ class EventsWorkerStore(SQLBaseStore):
row = txn.fetchone()
return row[0] if row else 0
- def get_current_state_event_counts(self, room_id):
+ async def get_current_state_event_counts(self, room_id: str) -> int:
"""
Gets the current number of state events in a room.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- Deferred[int]
+ The current number of state events.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
)
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -992,9 +977,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
+ dict[str, int] of complexity version to complexity.
"""
- state_events = yield self.get_current_state_event_counts(room_id)
+ state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@@ -1010,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ async def get_all_new_forward_event_rows(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple]:
"""Returns new events, for the Events replication stream
Args:
@@ -1018,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore):
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1039,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
- def get_ex_outlier_stream_rows(self, last_id, current_id):
+ async def get_ex_outlier_stream_rows(
+ self, last_id: int, current_id: int
+ ) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
- Returns: Deferred[List[Tuple]]
+ Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@@ -1073,13 +1062,36 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
- def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+ async def get_all_new_backfill_event_rows(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ """Get updates for backfill replication stream, including all new
+ backfilled events and events that have gone from being outliers to not.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data.
+ """
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_new_backfill_event_rows(txn):
sql = (
@@ -1094,10 +1106,12 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
- new_event_updates = txn.fetchall()
+ new_event_updates = [(row[0], row[1:]) for row in txn]
+ limited = False
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
+ limited = True
else:
upper_bound = current_id
@@ -1114,11 +1128,15 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
- new_event_updates.extend(txn.fetchall())
+ new_event_updates.extend((row[0], row[1:]) for row in txn)
+
+ if len(new_event_updates) >= limit:
+ upper_bound = new_event_updates[-1][0]
+ limited = True
- return new_event_updates
+ return new_event_updates, upper_bound, limited
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@@ -1166,7 +1184,7 @@ class EventsWorkerStore(SQLBaseStore):
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows = await self.db.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
) # type: List[Tuple]
@@ -1189,103 +1207,12 @@ class EventsWorkerStore(SQLBaseStore):
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
to_token += 1
- rows = await self.db.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
return rows, to_token, True
- @cached(num_args=5, max_entries=10)
- def get_all_new_events(
- self,
- last_backfill_id,
- last_forward_id,
- current_backfill_id,
- current_forward_id,
- limit,
- ):
- """Get all the new events that have arrived at the server either as
- new events or as backfilled events"""
- have_backfill_events = last_backfill_id != current_backfill_id
- have_forward_events = last_forward_id != current_forward_id
-
- if not have_backfill_events and not have_forward_events:
- return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
- def get_all_new_events_txn(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- if have_forward_events:
- txn.execute(sql, (last_forward_id, current_forward_id, limit))
- new_forward_events = txn.fetchall()
-
- if len(new_forward_events) == limit:
- upper_bound = new_forward_events[-1][0]
- else:
- upper_bound = current_forward_id
-
- sql = (
- "SELECT event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_forward_id, upper_bound))
- forward_ex_outliers = txn.fetchall()
- else:
- new_forward_events = []
- forward_ex_outliers = []
-
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
- if have_backfill_events:
- txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
- new_backfill_events = txn.fetchall()
-
- if len(new_backfill_events) == limit:
- upper_bound = new_backfill_events[-1][0]
- else:
- upper_bound = current_backfill_id
-
- sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_backfill_id, -upper_bound))
- backward_ex_outliers = txn.fetchall()
- else:
- new_backfill_events = []
- backward_ex_outliers = []
-
- return AllNewEventsResult(
- new_forward_events,
- new_backfill_events,
- forward_ex_outliers,
- backward_ex_outliers,
- )
-
- return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
-
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1293,9 +1220,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
- @cachedInlineCallbacks(max_entries=5000)
- def get_event_ordering(self, event_id):
- res = yield self.db.simple_select_one(
+ @cached(max_entries=5000)
+ async def get_event_ordering(self, event_id):
+ res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1307,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore):
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
- def get_next_event_to_expire(self):
+ async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
- Returns: Deferred[Optional[Tuple[str, int]]]
+ Returns:
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
@@ -1327,17 +1254,6 @@ class EventsWorkerStore(SQLBaseStore):
return txn.fetchone()
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
-
-
-AllNewEventsResult = namedtuple(
- "AllNewEventsResult",
- [
- "new_forward_events",
- "new_backfill_events",
- "forward_ex_outliers",
- "backward_ex_outliers",
- ],
-)
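
The `@overload` declarations added to `get_event` exist purely so that mypy can narrow the return type from the value of `allow_none`: with the default `allow_none=False` the result is an `EventBase`, while `allow_none=True` widens it to `Optional[EventBase]`. A caller-side sketch of the effect (the helper below is illustrative and assumes `store` is an `EventsWorkerStore`; it is not code from this patch):

    from typing import Optional

    from synapse.events import EventBase

    async def _example_caller(store, event_id: str) -> None:
        # mypy infers EventBase here, so no None check is needed.
        ev: EventBase = await store.get_event(event_id)

        # With allow_none=True the inferred type is Optional[EventBase].
        maybe_ev: Optional[EventBase] = await store.get_event(event_id, allow_none=True)
        if maybe_ev is not None:
            assert maybe_ev.event_id == ev.event_id
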
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 342d6622a4..d2f5b9a502 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,12 +17,13 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import JsonDict
+from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
- @cachedInlineCallbacks(num_args=2)
- def get_user_filter(self, user_localpart, filter_id):
+ @cached(num_args=2)
+ async def get_user_filter(self, user_localpart, filter_id):
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -30,7 +31,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
- def_json = yield self.db.simple_select_one_onecol(
+ def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- def add_user_filter(self, user_localpart, user_filter):
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
- return self.db.runInteraction("add_user_filter", _do_txn)
+ return await self.db_pool.runInteraction("add_user_filter", _do_txn)
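
`get_all_new_backfill_event_rows` above, and `get_all_groups_changes` in the group_server changes below, now return the `(updates, upto_token, limited)` triple used by the replication streams instead of a bare list. A rough sketch of how a consumer pages through such a stream (the loop is illustrative and assumes `store` is an `EventsWorkerStore`; it is not code from this patch):

    async def _drain_backfill_updates(store, instance_name: str, current_token: int):
        # Page through all backfill rows between token 0 and current_token.
        last_token = 0
        rows = []
        while True:
            updates, upto_token, limited = await store.get_all_new_backfill_event_rows(
                instance_name, last_token, current_token, limit=100
            )
            rows.extend(updates)  # each entry is (stream_id, row_data)
            last_token = upto_token
            if not limited:
                # Fewer rows than the limit remained, so we have caught up.
                break
        return rows
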
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 787ce9e584..1cbf31f52d 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,12 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Optional, Tuple, Union
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
+from synapse.util import json_encoder
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
- def get_group(self, group_id):
- return self.db.simple_select_one(
+ async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -44,31 +44,35 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group",
)
- def get_users_in_group(self, group_id, include_private=False):
+ async def get_users_in_group(
+ self, group_id: str, include_private: bool = False
+ ) -> List[Dict[str, Any]]:
# TODO: Pagination
keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
- return self.db.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group",
)
- def get_invited_users_in_group(self, group_id):
+ async def get_invited_users_in_group(self, group_id: str) -> List[str]:
# TODO: Pagination
- return self.db.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
desc="get_invited_users_in_group",
)
- def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+ async def get_rooms_in_group(
+ self, group_id: str, include_private: bool = False
+ ) -> List[Dict[str, Union[str, bool]]]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
@@ -77,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
- Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
- form of:
+ A list of dictionaries, each in the form of:
{
"room_id": "!a_room_id:example.com", # The ID of the room
@@ -115,11 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore):
for room_id, is_public in txn
]
- return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+ return await self.db_pool.runInteraction(
+ "get_rooms_in_group", _get_rooms_in_group_txn
+ )
- def get_rooms_for_summary_by_category(
+ async def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False,
- ):
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request
Args:
@@ -127,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
- Deferred[Tuple[List, Dict]]: A tuple containing:
+ A tuple containing:
* A list of dictionaries with the keys:
* "room_id": str, the room ID
@@ -195,7 +200,7 @@ class GroupServerWorkerStore(SQLBaseStore):
categories = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -203,13 +208,12 @@ class GroupServerWorkerStore(SQLBaseStore):
return rooms, categories
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
- @defer.inlineCallbacks
- def get_group_categories(self, group_id):
- rows = yield self.db.simple_select_list(
+ async def get_group_categories(self, group_id):
+ rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
@@ -219,27 +223,25 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["category_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
- @defer.inlineCallbacks
- def get_group_category(self, group_id, category_id):
- category = yield self.db.simple_select_one(
+ async def get_group_category(self, group_id, category_id):
+ category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
desc="get_group_category",
)
- category["profile"] = json.loads(category["profile"])
+ category["profile"] = db_to_json(category["profile"])
return category
- @defer.inlineCallbacks
- def get_group_roles(self, group_id):
- rows = yield self.db.simple_select_list(
+ async def get_group_roles(self, group_id):
+ rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
@@ -249,43 +251,42 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["role_id"]: {
"is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
+ "profile": db_to_json(row["profile"]),
}
for row in rows
}
- @defer.inlineCallbacks
- def get_group_role(self, group_id, role_id):
- role = yield self.db.simple_select_one(
+ async def get_group_role(self, group_id, role_id):
+ role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
desc="get_group_role",
)
- role["profile"] = json.loads(role["profile"])
+ role["profile"] = db_to_json(role["profile"])
return role
- def get_local_groups_for_room(self, room_id):
+ async def get_local_groups_for_room(self, room_id: str) -> List[str]:
"""Get all of the local group that contain a given room
Args:
- room_id (str): The ID of a room
+ room_id: The ID of a room
Returns:
- Deferred[list[str]]: A twisted.Deferred containing a list of group ids
- containing this room
+ A list of group ids containing this room
"""
- return self.db.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
desc="get_local_groups_for_room",
)
- def get_users_for_summary_by_role(self, group_id, include_private=False):
+ async def get_users_for_summary_by_role(self, group_id, include_private=False):
"""Get the users and roles that should be included in a summary request
- Returns ([users], [roles])
+ Returns:
+ ([users], [roles])
"""
def _get_users_for_summary_txn(txn):
@@ -331,7 +332,7 @@ class GroupServerWorkerStore(SQLBaseStore):
roles = {
row[0]: {
"is_public": row[1],
- "profile": json.loads(row[2]),
+ "profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -339,21 +340,24 @@ class GroupServerWorkerStore(SQLBaseStore):
return users, roles
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
- def is_user_in_group(self, user_id, group_id):
- return self.db.simple_select_one_onecol(
+ async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+ result = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
allow_none=True,
desc="is_user_in_group",
- ).addCallback(lambda r: bool(r))
+ )
+ return bool(result)
- def is_user_admin_in_group(self, group_id, user_id):
- return self.db.simple_select_one_onecol(
+ async def is_user_admin_in_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -361,10 +365,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
- def is_user_invited_to_local_group(self, group_id, user_id):
+ async def is_user_invited_to_local_group(
+ self, group_id: str, user_id: str
+ ) -> Optional[bool]:
"""Has the group server invited a user?
"""
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -372,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
- def get_users_membership_info_in_group(self, group_id, user_id):
+ async def get_users_membership_info_in_group(self, group_id, user_id):
"""Get a dict describing the membership of a user in a group.
Example if joined:
@@ -383,11 +389,12 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": False,
}
- Returns an empty dict if the user is not join/invite/etc
+ Returns:
+ An empty dict if the user is not join/invite/etc
"""
def _get_users_membership_in_group_txn(txn):
- row = self.db.simple_select_one_txn(
+ row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -402,7 +409,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": row["is_admin"],
}
- row = self.db.simple_select_one_onecol_txn(
+ row = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -415,21 +422,21 @@ class GroupServerWorkerStore(SQLBaseStore):
return {}
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
- def get_publicised_groups_for_user(self, user_id):
+ async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
"""Get all groups a user is publicising
"""
- return self.db.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
desc="get_publicised_groups_for_user",
)
- def get_attestations_need_renewals(self, valid_until_ms):
+ async def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time
"""
@@ -439,18 +446,17 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE valid_until_ms <= ?
"""
txn.execute(sql, (valid_until_ms,))
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
- @defer.inlineCallbacks
- def get_remote_attestation(self, group_id, user_id):
+ async def get_remote_attestation(self, group_id, user_id):
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
- row = yield self.db.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
@@ -460,19 +466,19 @@ class GroupServerWorkerStore(SQLBaseStore):
now = int(self._clock.time_msec())
if row and now < row["valid_until_ms"]:
- return json.loads(row["attestation_json"])
+ return db_to_json(row["attestation_json"])
return None
- def get_joined_groups(self, user_id):
- return self.db.simple_select_onecol(
+ async def get_joined_groups(self, user_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
desc="get_joined_groups",
)
- def get_all_groups_for_user(self, user_id, now_token):
+ async def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
sql = """
SELECT group_id, type, membership, u.content
@@ -487,22 +493,22 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": row[0],
"type": row[1],
"membership": row[2],
- "content": json.loads(row[3]),
+ "content": db_to_json(row[3]),
}
for row in txn
]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
- def get_groups_changes_for_user(self, user_id, from_token, to_token):
+ async def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token
)
if not has_changed:
- return defer.succeed([])
+ return []
def _get_groups_changes_for_user_txn(txn):
sql = """
@@ -517,22 +523,44 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": group_id,
"membership": membership,
"type": gtype,
- "content": json.loads(content_json),
+ "content": db_to_json(content_json),
}
for group_id, membership, gtype, content_json in txn
]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
- def get_all_groups_changes(self, from_token, to_token, limit):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(
- from_token
- )
+ async def get_all_groups_changes(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for groups replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data.
+ """
+
+ last_id = int(last_id)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+
if not has_changed:
- return defer.succeed([])
+ return [], current_id, False
def _get_all_groups_changes_txn(txn):
sql = """
@@ -541,34 +569,61 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, limit))
- return [
- (stream_id, group_id, user_id, gtype, json.loads(content_json))
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
- return self.db.runInteraction(
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
+
+ return updates, upto_token, limited
+
+ return await self.db_pool.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)
class GroupServerStore(GroupServerWorkerStore):
- def set_group_join_policy(self, group_id, join_policy):
+ async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
"""Set the join policy of a group.
join_policy can be one of:
* "invite"
* "open"
"""
- return self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
desc="set_group_join_policy",
)
- def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
- return self.db.runInteraction(
+ async def add_room_to_summary(
+ self,
+ group_id: str,
+ room_id: str,
+ category_id: str,
+ order: int,
+ is_public: Optional[bool],
+ ) -> None:
+ """Add (or update) room's entry in summary.
+
+ Args:
+ group_id
+ room_id
+ category_id: If not None then adds the category to the end of
+ the summary if it's not already there.
+ order: If not None inserts the room at that position, e.g. an order
+ of 1 will put the room first. Otherwise, the room gets added to
+ the end.
+ is_public
+ """
+ await self.db_pool.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@@ -579,20 +634,28 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_room_to_summary_txn(
- self, txn, group_id, room_id, category_id, order, is_public
- ):
+ self,
+ txn,
+ group_id: str,
+ room_id: str,
+ category_id: str,
+ order: int,
+ is_public: Optional[bool],
+ ) -> None:
"""Add (or update) room's entry in summary.
Args:
- group_id (str)
- room_id (str)
- category_id (str): If not None then adds the category to the end of
- the summary if its not already there. [Optional]
- order (int): If not None inserts the room at that position, e.g.
- an order of 1 will put the room first. Otherwise, the room gets
- added to the end.
+ txn
+ group_id
+ room_id
+ category_id: If not None then adds the category to the end of
+ the summary if it's not already there.
+ order: If not None inserts the room at that position, e.g. an order
+ of 1 will put the room first. Otherwise, the room gets added to
+ the end.
+ is_public
"""
- room_in_group = self.db.simple_select_one_onecol_txn(
+ room_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@@ -605,7 +668,7 @@ class GroupServerStore(GroupServerWorkerStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
else:
- cat_exists = self.db.simple_select_one_onecol_txn(
+ cat_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -616,7 +679,7 @@ class GroupServerStore(GroupServerWorkerStore):
raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already
- cat_exists = self.db.simple_select_one_onecol_txn(
+ cat_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -636,7 +699,7 @@ class GroupServerStore(GroupServerWorkerStore):
(group_id, category_id, group_id, category_id),
)
- existing = self.db.simple_select_one_txn(
+ existing = self.db_pool.simple_select_one_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -669,7 +732,7 @@ class GroupServerStore(GroupServerWorkerStore):
to_update["room_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -683,7 +746,7 @@ class GroupServerStore(GroupServerWorkerStore):
if is_public is None:
is_public = True
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_summary_rooms",
values={
@@ -695,11 +758,13 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
- def remove_room_from_summary(self, group_id, room_id, category_id):
+ async def remove_room_from_summary(
+ self, group_id: str, room_id: str, category_id: str
+ ) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
- return self.db.simple_delete(
+ return await self.db_pool.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@@ -709,7 +774,13 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_room_from_summary",
)
- def upsert_group_category(self, group_id, category_id, profile, is_public):
+ async def upsert_group_category(
+ self,
+ group_id: str,
+ category_id: str,
+ profile: Optional[JsonDict],
+ is_public: Optional[bool],
+ ) -> None:
"""Add/update room category for group
"""
insertion_values = {}
@@ -718,14 +789,14 @@ class GroupServerStore(GroupServerWorkerStore):
if profile is None:
insertion_values["profile"] = "{}"
else:
- update_values["profile"] = json.dumps(profile)
+ update_values["profile"] = json_encoder.encode(profile)
if is_public is None:
insertion_values["is_public"] = True
else:
update_values["is_public"] = is_public
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@@ -733,14 +804,20 @@ class GroupServerStore(GroupServerWorkerStore):
desc="upsert_group_category",
)
- def remove_group_category(self, group_id, category_id):
- return self.db.simple_delete(
+ async def remove_group_category(self, group_id: str, category_id: str) -> int:
+ return await self.db_pool.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
)
- def upsert_group_role(self, group_id, role_id, profile, is_public):
+ async def upsert_group_role(
+ self,
+ group_id: str,
+ role_id: str,
+ profile: Optional[JsonDict],
+ is_public: Optional[bool],
+ ) -> None:
"""Add/remove user role
"""
insertion_values = {}
@@ -749,14 +826,14 @@ class GroupServerStore(GroupServerWorkerStore):
if profile is None:
insertion_values["profile"] = "{}"
else:
- update_values["profile"] = json.dumps(profile)
+ update_values["profile"] = json_encoder.encode(profile)
if is_public is None:
insertion_values["is_public"] = True
else:
update_values["is_public"] = is_public
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -764,15 +841,34 @@ class GroupServerStore(GroupServerWorkerStore):
desc="upsert_group_role",
)
- def remove_group_role(self, group_id, role_id):
- return self.db.simple_delete(
+ async def remove_group_role(self, group_id: str, role_id: str) -> int:
+ return await self.db_pool.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
- def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
- return self.db.runInteraction(
+ async def add_user_to_summary(
+ self,
+ group_id: str,
+ user_id: str,
+ role_id: str,
+ order: int,
+ is_public: Optional[bool],
+ ) -> None:
+ """Add (or update) user's entry in summary.
+
+ Args:
+ group_id
+ user_id
+ role_id: If not None then adds the role to the end of the summary if
+                it's not already there.
+ order: If not None inserts the user at that position, e.g. an order
+ of 1 will put the user first. Otherwise, the user gets added to
+ the end.
+ is_public
+ """
+ await self.db_pool.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@@ -783,20 +879,28 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_user_to_summary_txn(
- self, txn, group_id, user_id, role_id, order, is_public
+ self,
+ txn,
+ group_id: str,
+ user_id: str,
+ role_id: str,
+ order: int,
+ is_public: Optional[bool],
):
"""Add (or update) user's entry in summary.
Args:
- group_id (str)
- user_id (str)
- role_id (str): If not None then adds the role to the end of
- the summary if its not already there. [Optional]
- order (int): If not None inserts the user at that position, e.g.
- an order of 1 will put the user first. Otherwise, the user gets
- added to the end.
+ txn
+ group_id
+ user_id
+ role_id: If not None then adds the role to the end of the summary if
+                it's not already there.
+ order: If not None inserts the user at that position, e.g. an order
+ of 1 will put the user first. Otherwise, the user gets added to
+ the end.
+ is_public
"""
- user_in_group = self.db.simple_select_one_onecol_txn(
+ user_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -809,7 +913,7 @@ class GroupServerStore(GroupServerWorkerStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
else:
- role_exists = self.db.simple_select_one_onecol_txn(
+ role_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -820,7 +924,7 @@ class GroupServerStore(GroupServerWorkerStore):
raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already
- role_exists = self.db.simple_select_one_onecol_txn(
+ role_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -840,7 +944,7 @@ class GroupServerStore(GroupServerWorkerStore):
(group_id, role_id, group_id, role_id),
)
- existing = self.db.simple_select_one_txn(
+ existing = self.db_pool.simple_select_one_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -869,7 +973,7 @@ class GroupServerStore(GroupServerWorkerStore):
to_update["user_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="group_summary_users",
keyvalues={
@@ -883,7 +987,7 @@ class GroupServerStore(GroupServerWorkerStore):
if is_public is None:
is_public = True
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_summary_users",
values={
@@ -895,50 +999,51 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
- def remove_user_from_summary(self, group_id, user_id, role_id):
+ async def remove_user_from_summary(
+ self, group_id: str, user_id: str, role_id: str
+ ) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
- return self.db.simple_delete(
+ return await self.db_pool.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
)
- def add_group_invite(self, group_id, user_id):
+ async def add_group_invite(self, group_id: str, user_id: str) -> None:
"""Record that the group server has invited a user
"""
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
)
- def add_user_to_group(
+ async def add_user_to_group(
self,
- group_id,
- user_id,
- is_admin=False,
- is_public=True,
- local_attestation=None,
- remote_attestation=None,
- ):
+ group_id: str,
+ user_id: str,
+ is_admin: bool = False,
+ is_public: bool = True,
+        local_attestation: Optional[dict] = None,
+        remote_attestation: Optional[dict] = None,
+ ) -> None:
"""Add a user to the group server.
Args:
- group_id (str)
- user_id (str)
- is_admin (bool)
- is_public (bool)
- local_attestation (dict): The attestation the GS created to give
- to the remote server. Optional if the user and group are on the
- same server
- remote_attestation (dict): The attestation given to GS by remote
+ group_id
+ user_id
+ is_admin
+ is_public
+ local_attestation: The attestation the GS created to give to the remote
server. Optional if the user and group are on the same server
+ remote_attestation: The attestation given to GS by remote server.
+ Optional if the user and group are on the same server
"""
def _add_user_to_group_txn(txn):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_users",
values={
@@ -949,14 +1054,14 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -966,137 +1071,145 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
if remote_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json.dumps(remote_attestation),
+ "attestation_json": json_encoder.encode(remote_attestation),
},
)
- return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
+ await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
- def remove_user_from_group(self, group_id, user_id):
+ async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
- def change_user_admin_in_group(self, group_id, user_id, is_admin):
- return self.db.simple_update(
+ async def change_user_admin_in_group(
+ self, group_id: str, user_id: str, is_admin: bool
+ ) -> int:
+ return await self.db_pool.simple_update(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_admin": is_admin},
desc="change_user_admin_in_group"
)
- def add_room_to_group(self, group_id, room_id, is_public):
- return self.db.simple_insert(
+ async def add_room_to_group(
+ self, group_id: str, room_id: str, is_public: bool
+ ) -> None:
+ await self.db_pool.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
- def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self.db.simple_update(
+ async def update_room_in_group_visibility(
+ self, group_id: str, room_id: str, is_public: bool
+ ) -> int:
+ return await self.db_pool.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
desc="update_room_in_group_visibility",
)
- def remove_room_from_group(self, group_id, room_id):
+ async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
- def update_group_publicity(self, group_id, user_id, publicise):
+ async def update_group_publicity(
+ self, group_id: str, user_id: str, publicise: bool
+ ) -> None:
"""Update whether the user is publicising their membership of the group
"""
- return self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
desc="update_group_publicity",
)
- @defer.inlineCallbacks
- def register_user_group_membership(
+ async def register_user_group_membership(
self,
- group_id,
- user_id,
- membership,
- is_admin=False,
- content={},
- local_attestation=None,
- remote_attestation=None,
- is_publicised=False,
- ):
+ group_id: str,
+ user_id: str,
+ membership: str,
+ is_admin: bool = False,
+ content: JsonDict = {},
+ local_attestation: Optional[dict] = None,
+ remote_attestation: Optional[dict] = None,
+ is_publicised: bool = False,
+ ) -> int:
"""Registers that a local user is a member of a (local or remote) group.
Args:
- group_id (str)
- user_id (str)
- membership (str)
- is_admin (bool)
- content (dict): Content of the membership, e.g. includes the inviter
+ group_id: The group the member is being added to.
+            user_id: The user ID to add to the group.
+ membership: The type of group membership.
+ is_admin: Whether the user should be added as a group admin.
+ content: Content of the membership, e.g. includes the inviter
if the user has been invited.
- local_attestation (dict): If remote group then store the fact that we
+ local_attestation: If remote group then store the fact that we
have given out an attestation, else None.
- remote_attestation (dict): If remote group then store the remote
+ remote_attestation: If remote group then store the remote
attestation from the group, else None.
+ is_publicised: Whether this should be publicised.
"""
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="local_group_membership",
values={
@@ -1105,11 +1218,11 @@ class GroupServerStore(GroupServerWorkerStore):
"is_admin": is_admin,
"membership": membership,
"is_publicised": is_publicised,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="local_group_updates",
values={
@@ -1117,7 +1230,7 @@ class GroupServerStore(GroupServerWorkerStore):
"group_id": group_id,
"user_id": user_id,
"type": "membership",
- "content": json.dumps(
+ "content": json_encoder.encode(
{"membership": membership, "content": content}
),
},
@@ -1128,7 +1241,7 @@ class GroupServerStore(GroupServerWorkerStore):
if membership == "join":
if local_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -1138,23 +1251,23 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
if remote_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json.dumps(remote_attestation),
+ "attestation_json": json_encoder.encode(remote_attestation),
},
)
else:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1162,19 +1275,18 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with self._group_updates_id_gen.get_next() as next_id:
- res = yield self.db.runInteraction(
+ with await self._group_updates_id_gen.get_next() as next_id:
+ res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
)
return res
- @defer.inlineCallbacks
- def create_group(
+ async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
- ):
- yield self.db.simple_insert(
+ ) -> None:
+ await self.db_pool.simple_insert(
table="groups",
values={
"group_id": group_id,
@@ -1187,48 +1299,51 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
- @defer.inlineCallbacks
- def update_group_profile(self, group_id, profile):
- yield self.db.simple_update_one(
+ async def update_group_profile(self, group_id, profile):
+ await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
desc="update_group_profile",
)
- def update_attestation_renewal(self, group_id, user_id, attestation):
+ async def update_attestation_renewal(
+ self, group_id: str, user_id: str, attestation: dict
+ ) -> None:
"""Update an attestation that we have renewed
"""
- return self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
desc="update_attestation_renewal",
)
- def update_remote_attestion(self, group_id, user_id, attestation):
+ async def update_remote_attestion(
+ self, group_id: str, user_id: str, attestation: dict
+ ) -> None:
"""Update an attestation that a remote has renewed
"""
- return self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
"valid_until_ms": attestation["valid_until_ms"],
- "attestation_json": json.dumps(attestation),
+ "attestation_json": json_encoder.encode(attestation),
},
desc="update_remote_attestion",
)
- def remove_attestation_renewal(self, group_id, user_id):
+ async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int:
"""Remove an attestation that we thought we should renew, but actually
shouldn't. Ideally this would never get called as we would never
incorrectly try and do attestations for local users on local groups.
Args:
- group_id (str)
- user_id (str)
+ group_id
+ user_id
"""
- return self.db.simple_delete(
+ return await self.db_pool.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
@@ -1237,14 +1352,11 @@ class GroupServerStore(GroupServerWorkerStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
- def delete_group(self, group_id):
+ async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
Args:
- group_id (str)
-
- Returns:
- Deferred
+ group_id: The group ID to delete.
"""
def _delete_group_txn(txn):
@@ -1264,8 +1376,8 @@ class GroupServerStore(GroupServerWorkerStore):
]
for table in tables:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table=table, keyvalues={"group_id": group_id}
)
- return self.db.runInteraction("delete_group", _delete_group_txn)
+ await self.db_pool.runInteraction("delete_group", _delete_group_txn)
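
The change pattern throughout the group-server store above is mechanical: a method that used to hand back a Deferred from `self.db.simple_*` or `self.db.runInteraction` becomes a coroutine that awaits the same call on `self.db_pool`. A minimal standalone sketch of that shape, using invented stand-ins (`FakeDatabasePool`, `FakeGroupServerStore`) rather than Synapse's real classes:

```python
# Illustrative only: stand-ins for the Deferred -> async/await conversion
# applied throughout GroupServerStore. The classes below are hypothetical.
import asyncio


class FakeDatabasePool:
    async def simple_delete(self, table, keyvalues, desc):
        # A real DatabasePool would issue SQL here; we just log the call.
        print(f"DELETE FROM {table} WHERE {keyvalues}  -- {desc}")
        return 1  # rows deleted


class FakeGroupServerStore:
    def __init__(self):
        self.db_pool = FakeDatabasePool()

    async def remove_group_role(self, group_id: str, role_id: str) -> int:
        # After the refactor the store awaits the pool directly instead of
        # handing a Deferred back to the caller.
        return await self.db_pool.simple_delete(
            table="group_roles",
            keyvalues={"group_id": group_id, "role_id": role_id},
            desc="remove_group_role",
        )


async def main():
    deleted = await FakeGroupServerStore().remove_group_role("+g:example.org", "mods")
    print("deleted rows:", deleted)


asyncio.run(main())
```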
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/databases/main/keys.py
index 4e1642a27a..ad43bb05ab 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
import itertools
import logging
+from typing import Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -41,16 +42,17 @@ class KeyStore(SQLBaseStore):
@cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
)
- def get_server_verify_keys(self, server_name_and_key_ids):
+ async def get_server_verify_keys(
+ self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+ ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
"""
Args:
- server_name_and_key_ids (iterable[Tuple[str, str]]):
+ server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
- map from (server_name, key_id) -> FetchKeyResult, or None if the key is
- unknown
+ A map from (server_name, key_id) -> FetchKeyResult, or None if the
+ key is unknown
"""
keys = {}
@@ -86,14 +88,19 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
- return self.db.runInteraction("get_server_verify_keys", _txn)
+ return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ async def store_server_verify_keys(
+ self,
+ from_server: str,
+ ts_added_ms: int,
+ verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+ ) -> None:
"""Stores NACL verification keys for remote servers.
Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ from_server: Where the verification keys were looked up
+ ts_added_ms: The time to record that the key was added
+ verify_keys:
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
"""
@@ -115,15 +122,9 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i,))
- return res
-
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"store_server_verify_keys",
- self.db.simple_upsert_many_txn,
+ self.db_pool.simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@@ -134,24 +135,34 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
- ).addCallback(_invalidate)
+ )
- def store_server_keys_json(
- self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
- ):
+ invalidate = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ invalidate((i,))
+
+ async def store_server_keys_json(
+ self,
+ server_name: str,
+ key_id: str,
+ from_server: str,
+ ts_now_ms: int,
+ ts_expires_ms: int,
+ key_json_bytes: bytes,
+ ) -> None:
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed.
Args:
- server_name (str): The name of the server.
- key_id (str): The identifer of the key this JSON is for.
- from_server (str): The server this JSON was fetched from.
- ts_now_ms (int): The time now in milliseconds.
- ts_valid_until_ms (int): The time when this json stops being valid.
- key_json (bytes): The encoded JSON.
+ server_name: The name of the server.
+            key_id: The identifier of the key this JSON is for.
+ from_server: The server this JSON was fetched from.
+ ts_now_ms: The time now in milliseconds.
+            ts_expires_ms: The time when this JSON stops being valid.
+ key_json_bytes: The encoded JSON.
"""
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
@@ -169,7 +180,9 @@ class KeyStore(SQLBaseStore):
desc="store_server_keys_json",
)
- def get_server_keys_json(self, server_keys):
+ async def get_server_keys_json(
+ self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
+ ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
"""Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
@@ -178,8 +191,7 @@ class KeyStore(SQLBaseStore):
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
- Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
- Dict mapping (server_name, key_id, source) triplets to lists of dicts
+ A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
def _get_server_keys_json_txn(txn):
@@ -190,7 +202,7 @@ class KeyStore(SQLBaseStore):
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
@@ -205,4 +217,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
- return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
+ return await self.db_pool.runInteraction(
+ "get_server_keys_json", _get_server_keys_json_txn
+ )
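
One detail worth noting in `store_server_verify_keys`: cache invalidation for `_get_server_verify_key` used to be attached to the returned Deferred via `addCallback(_invalidate)`; in the coroutine version it simply runs after the awaited interaction. A rough sketch of that control-flow change, with hypothetical stand-ins for the cache and the transaction runner:

```python
# Hypothetical stand-ins, not Synapse's real cache or DatabasePool.
import asyncio


class FakeCache:
    def invalidate(self, key):
        print("invalidated cache entry", key)


async def run_interaction(desc):
    print("ran transaction:", desc)


async def store_server_verify_keys(invalidations):
    cache = FakeCache()
    # The upsert happens inside the transaction...
    await run_interaction("store_server_verify_keys")
    # ...and invalidation now follows the await in straight-line code,
    # rather than being added as a callback on the returned Deferred.
    for server_name, key_id in invalidations:
        cache.invalidate(((server_name, key_id),))


asyncio.run(store_server_verify_keys([("matrix.org", "ed25519:auto")]))
```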
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 8aecd414c2..86557d5512 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryBackgroundUpdateStore, self).__init__(
database, db_conn, hs
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
index_name="local_media_repository_url_idx",
table="local_media_repository",
@@ -34,15 +36,16 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
- def get_local_media(self, media_id):
+ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
+
Returns:
None if the media_id doesn't exist.
"""
- return self.db.simple_select_one(
+ return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -57,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media",
)
- def store_local_media(
+ async def store_local_media(
self,
media_id,
media_type,
@@ -66,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
user_id,
url_cache=None,
- ):
- return self.db.simple_insert(
+ ) -> None:
+ await self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -81,7 +84,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
- def get_url_cache(self, url, ts):
+ async def mark_local_media_as_safe(self, media_id: str) -> None:
+ """Mark a local media as safe from quarantining."""
+ await self.db_pool.simple_update_one(
+ table="local_media_repository",
+ keyvalues={"media_id": media_id},
+ updatevalues={"safe_from_quarantine": True},
+ desc="mark_local_media_as_safe",
+ )
+
+ async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
@@ -127,12 +139,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
- return self.db.runInteraction("get_url_cache", get_url_cache_txn)
+ return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
- def store_url_cache(
+ async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -146,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache",
)
- def get_local_media_thumbnails(self, media_id):
- return self.db.simple_select_list(
+ async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -160,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media_thumbnails",
)
- def store_local_thumbnail(
+ async def store_local_thumbnail(
self,
media_id,
thumbnail_width,
@@ -169,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -182,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
- def get_cached_remote_media(self, origin, media_id):
- return self.db.simple_select_one(
+ async def get_cached_remote_media(
+ self, origin, media_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -198,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_cached_remote_media",
)
- def store_cached_remote_media(
+ async def store_cached_remote_media(
self,
origin,
media_id,
@@ -208,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -223,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ async def update_cached_last_access_time(
+ self,
+ local_media: Iterable[str],
+ remote_media: Iterable[Tuple[str, str]],
+ time_ms: int,
+ ):
"""Updates the last access time of the given media
Args:
- local_media (iterable[str]): Set of media_ids
- remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ local_media: Set of media_ids
+ remote_media: Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
@@ -253,12 +272,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
- def get_remote_media_thumbnails(self, origin, media_id):
- return self.db.simple_select_list(
+ async def get_remote_media_thumbnails(
+ self, origin: str, media_id: str
+ ) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -272,7 +293,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
- def store_remote_media_thumbnail(
+ async def store_remote_media_thumbnail(
self,
origin,
media_id,
@@ -283,7 +304,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@@ -298,33 +319,35 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
- def get_remote_media_before(self, before_ts):
+ async def get_remote_media_before(self, before_ts):
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
- return self.db.execute(
- "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
+ return await self.db_pool.execute(
+ "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
- def delete_remote_media(self, media_origin, media_id):
+ async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
+ await self.db_pool.runInteraction(
+ "delete_remote_media", delete_remote_media_txn
+ )
- def get_expired_url_cache(self, now_ts):
+ async def get_expired_url_cache(self, now_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
@@ -336,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@@ -349,9 +372,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
+ return await self.db_pool.runInteraction(
+ "delete_url_cache", _delete_url_cache_txn
+ )
- def get_url_cache_media_before(self, before_ts):
+ async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
@@ -363,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
@@ -380,6 +405,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
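
`mark_local_media_as_safe`, added above, flips the `safe_from_quarantine` flag through `simple_update_one`. Roughly speaking that helper reduces to a single parameterised UPDATE; the sqlite3 sketch below is an approximation (the exact SQL Synapse generates is not shown in this diff), with the schema trimmed to the relevant columns:

```python
# Sketch only: approximates what simple_update_one does for
# mark_local_media_as_safe. Table schema here is deliberately minimal.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE local_media_repository ("
    " media_id TEXT PRIMARY KEY,"
    " safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0)"
)
conn.execute("INSERT INTO local_media_repository (media_id) VALUES ('abcdef123')")

# keyvalues={"media_id": ...}, updatevalues={"safe_from_quarantine": True}
conn.execute(
    "UPDATE local_media_repository SET safe_from_quarantine = ? WHERE media_id = ?",
    (True, "abcdef123"),
)

print(conn.execute(
    "SELECT media_id, safe_from_quarantine FROM local_media_repository"
).fetchall())
# -> [('abcdef123', 1)]
```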
diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/databases/main/metrics.py
index dad5bbc602..686052bd83 100644
--- a/synapse/storage/data_stores/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -15,15 +15,13 @@
import typing
from collections import Counter
-from twisted.internet import defer
-
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.event_push_actions import (
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
-from synapse.storage.database import Database
class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
@@ -31,7 +29,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
# Collect metrics on the number of forward extremities that exist.
@@ -66,11 +64,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
return txn.fetchall()
- res = await self.db.runInteraction("read_forward_extremities", fetch)
+ res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res])
- @defer.inlineCallbacks
- def count_daily_messages(self):
+ async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_messages", _count_messages)
- return ret
+ return await self.db_pool.runInteraction("count_messages", _count_messages)
- @defer.inlineCallbacks
- def count_daily_sent_messages(self):
+ async def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then thats your own fault.
@@ -109,11 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
- return ret
+ return await self.db_pool.runInteraction(
+ "count_daily_sent_messages", _count_messages
+ )
- @defer.inlineCallbacks
- def count_daily_active_rooms(self):
+ async def count_daily_active_rooms(self):
def _count(txn):
sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
@@ -124,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
- return ret
+ return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
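
The metrics methods all share one shape: run a single aggregate SELECT inside a transaction and return the lone count. The WHERE clause of `count_daily_active_rooms` is truncated in the hunk above, so the time filter in the sketch below is an assumption rather than the store's actual query:

```python
# Sketch of the "count distinct rooms with recent messages" aggregate.
# The time filter and column names are assumed, not copied from the store.
import sqlite3
import time

now_ms = int(time.time() * 1000)
day_ago_ms = now_ms - 24 * 60 * 60 * 1000

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE events (room_id TEXT, type TEXT, received_ts INTEGER)"
)
conn.executemany(
    "INSERT INTO events VALUES (?, ?, ?)",
    [
        ("!a:example.org", "m.room.message", now_ms - 1000),
        ("!a:example.org", "m.room.message", now_ms - 2000),
        ("!b:example.org", "m.room.message", day_ago_ms - 5000),  # too old
    ],
)

(count,) = conn.execute(
    "SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events"
    " WHERE type = 'm.room.message' AND received_ts > ?",
    (day_ago_ms,),
).fetchone()
print(count)  # 1
```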
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e459cf49a0..1d793d3deb 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
-
-from twisted.internet import defer
+from typing import Dict, List
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -29,17 +27,17 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@cached(num_args=0)
- def get_monthly_active_count(self):
+ async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
Returns:
- Defered[int]: Number of current monthly active users
+ Number of current monthly active users
"""
def _count_users(txn):
@@ -48,10 +46,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- return self.db.runInteraction("count_users", _count_users)
+ return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
- def get_monthly_active_count_by_service(self):
+ async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
@@ -59,8 +57,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
method to return anything other than native matrix users.
Returns:
- Deferred[dict]: dict that includes a mapping between app_service_id
- and the number of occurrences.
+ A mapping between app_service_id and the number of occurrences.
"""
@@ -76,7 +73,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
result = txn.fetchall()
return dict(result)
- return self.db.runInteraction("count_users_by_service", _count_users_by_service)
+ return await self.db_pool.runInteraction(
+ "count_users_by_service", _count_users_by_service
+ )
async def get_registered_reserved_users(self) -> List[str]:
"""Of the reserved threepids defined in config, retrieve those that are associated
@@ -99,17 +98,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
- def user_last_seen_monthly_active(self, user_id):
+ async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
- Checks if a given user is part of the monthly active user group
- Arguments:
- user_id (str): user to add/update
- Return:
- Deferred[int] : timestamp since last seen, None if never seen
+ Checks if a given user is part of the monthly active user group
+
+ Arguments:
+                user_id: user to check
+            Return:
+                Timestamp of when the user was last seen, or None if never seen
"""
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
@@ -119,7 +119,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
@@ -128,7 +128,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# Do not add more reserved users than the total allowable number
# cur = LoggingTransaction(
- self.db.new_transaction(
+ self.db_pool.new_transaction(
db_conn,
"initialise_mau_threepids",
[],
@@ -162,7 +162,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
# We do this manually here to avoid hitting #6791
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@@ -246,20 +246,16 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
reserved_users = await self.get_registered_reserved_users()
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
- @defer.inlineCallbacks
- def upsert_monthly_active_user(self, user_id):
+ async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
Args:
- user_id (str): user to add/update
-
- Returns:
- Deferred
+ user_id: user to add/update
"""
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -269,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# _initialise_reserved_users reasoning that it would be very strange to
# include a support user in this context.
- is_support = yield self.is_support_user(user_id)
+ is_support = await self.is_support_user(user_id)
if is_support:
return
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
@@ -303,7 +299,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = self.db.simple_upsert_txn(
+ is_insert = self.db_pool.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@@ -320,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
return is_insert
- @defer.inlineCallbacks
- def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@@ -330,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = yield self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id)
if is_guest:
return
- is_trial = yield self.is_trial_user(user_id)
+ is_trial = await self.is_trial_user(user_id)
if is_trial:
return
- last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+ last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec()
# We want to reduce to the total number of db writes, and are happy
@@ -350,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# False, there is no point in checking get_monthly_active_count - it
# adds no value and will break the logic if max_mau_value is exceeded.
if not self._limit_usage_by_mau:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
else:
- count = yield self.get_monthly_active_count()
+ count = await self.get_monthly_active_count()
if count < self._max_mau_value:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
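
`populate_monthly_active_users` carries the interesting policy: guests and trial users never count, stats-only mode always records activity, and once the MAU cap is hit only a user we already track gets refreshed, and only when their timestamp is older than `LAST_SEEN_GRANULARITY`. A condensed, standalone restatement of that branching (the helper below is hypothetical and deliberately simplified):

```python
# Rough restatement of populate_monthly_active_users' decision logic.
LAST_SEEN_GRANULARITY = 60 * 60 * 1000  # one hour, as in the store


def should_upsert_mau(limit_by_mau, mau_stats_only, is_guest, is_trial,
                      last_seen_ts, now_ms, current_count, max_mau):
    if not (limit_by_mau or mau_stats_only):
        return False          # MAU tracking disabled entirely
    if is_guest or is_trial:
        return False          # guests and trial users are never counted
    if not limit_by_mau:
        return True           # stats-only mode: always record activity
    if current_count < max_mau:
        return True           # under the cap: record the user
    # At or over the cap: only refresh a user we already track, and only
    # if their last-seen timestamp is older than the granularity window.
    return last_seen_ts is not None and now_ms - last_seen_ts > LAST_SEEN_GRANULARITY


print(should_upsert_mau(True, False, False, False,
                        last_seen_ts=0, now_ms=10_000_000,
                        current_count=5, max_mau=100))  # True
```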
diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/databases/main/openid.py
index cc21437e92..2aac64901b 100644
--- a/synapse/storage/data_stores/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -1,9 +1,13 @@
+from typing import Optional
+
from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
- def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self.db.simple_insert(
+ async def insert_open_id_token(
+ self, token: str, ts_valid_until_ms: int, user_id: str
+ ) -> None:
+ await self.db_pool.simple_insert(
table="open_id_tokens",
values={
"token": token,
@@ -13,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
desc="insert_open_id_token",
)
- def get_user_id_for_open_id_token(self, token, ts_now_ms):
+ async def get_user_id_for_open_id_token(
+ self, token: str, ts_now_ms: int
+ ) -> Optional[str]:
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
@@ -28,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/databases/main/presence.py
index dab31e0c2d..c9f655dfb7 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -13,23 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from typing import List, Tuple
+from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ async def update_presence(self, presence_states):
+ stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states)
)
with stream_ordering_manager as stream_orderings:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
stream_orderings,
@@ -46,7 +45,7 @@ class PresenceStore(SQLBaseStore):
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="presence_stream",
values=[
@@ -73,9 +72,32 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
- def get_all_presence_updates(self, last_id, current_id, limit):
+ async def get_all_presence_updates(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ """Get updates for presence replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_presence_updates_txn(txn):
sql = """
@@ -89,9 +111,17 @@ class PresenceStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ updates = [(row[0], row[1:]) for row in txn]
+
+ upper_bound = current_id
+ limited = False
+ if len(updates) >= limit:
+ upper_bound = updates[-1][0]
+ limited = True
- return self.db.runInteraction(
+ return updates, upper_bound, limited
+
+ return await self.db_pool.runInteraction(
"get_all_presence_updates", get_all_presence_updates_txn
)
@@ -100,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_presence_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
- def get_presence_for_users(self, user_ids):
- rows = yield self.db.simple_select_many_batch(
+ async def get_presence_for_users(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -130,24 +157,3 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
-
- def allow_presence_visible(self, observed_localpart, observer_userid):
- return self.db.simple_insert(
- table="presence_allow_inbound",
- values={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="allow_presence_visible",
- or_ignore=True,
- )
-
- def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self.db.simple_delete_one(
- table="presence_allow_inbound",
- keyvalues={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="disallow_presence_visible",
- )
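
The new docstring on `get_all_presence_updates` spells out the replication-stream convention: return the row updates, an upper-bound token for the next fetch, and a flag saying whether the limit truncated the result. A self-contained illustration of that convention, with invented row contents:

```python
# Illustrates the (updates, upto_token, limited) shape described above.
def page_updates(rows, current_id, limit):
    """rows: (stream_id, *data) tuples in ascending stream_id order."""
    page = rows[:limit]                       # stands in for SQL's LIMIT ?
    updates = [(row[0], row[1:]) for row in page]

    upper_bound = current_id
    limited = False
    if len(updates) >= limit:
        # We may have been truncated; the caller should fetch again
        # starting from the last stream id we actually returned.
        upper_bound = updates[-1][0]
        limited = True
    return updates, upper_bound, limited


rows = [(i, "@user%d:example.org" % i, "online") for i in range(1, 6)]
updates, upto, limited = page_updates(rows, current_id=5, limit=3)
print(len(updates), upto, limited)  # 3 3 True
```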
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/databases/main/profile.py
index bfc9369f0b..d2e0685e9e 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,19 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
+from typing import Any, Dict, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.roommember import ProfileInfo
+from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_profileinfo(self, user_localpart):
+ async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
- profile = yield self.db.simple_select_one(
+ profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@@ -41,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
- def get_profile_displayname(self, user_localpart):
- return self.db.simple_select_one_onecol(
+ async def get_profile_displayname(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
- def get_profile_avatar_url(self, user_localpart):
- return self.db.simple_select_one_onecol(
+ async def get_profile_avatar_url(self, user_localpart: str) -> str:
+ return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
- def get_from_remote_profile_cache(self, user_id):
- return self.db.simple_select_one(
+ async def get_from_remote_profile_cache(
+ self, user_id: str
+ ) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -66,21 +66,25 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
- def create_profile(self, user_localpart):
- return self.db.simple_insert(
+ async def create_profile(self, user_localpart: str) -> None:
+ await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db.simple_update_one(
+ async def set_profile_displayname(
+ self, user_localpart: str, new_displayname: str
+ ) -> None:
+ await self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
desc="set_profile_displayname",
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db.simple_update_one(
+ async def set_profile_avatar_url(
+ self, user_localpart: str, new_avatar_url: str
+ ) -> None:
+ await self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@@ -89,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore):
class ProfileStore(ProfileWorkerStore):
- def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+ async def add_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> None:
"""Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -106,8 +112,10 @@ class ProfileStore(ProfileWorkerStore):
desc="add_remote_profile_cache",
)
- def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db.simple_update(
+ async def update_remote_profile_cache(
+ self, user_id: str, displayname: str, avatar_url: str
+ ) -> int:
+ return await self.db_pool.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
updatevalues={
@@ -118,20 +126,21 @@ class ProfileStore(ProfileWorkerStore):
desc="update_remote_profile_cache",
)
- @defer.inlineCallbacks
- def maybe_delete_remote_profile_cache(self, user_id):
+ async def maybe_delete_remote_profile_cache(self, user_id):
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
- subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+ subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
- yield self.db.simple_delete(
+ await self.db_pool.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
)
- def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ async def get_remote_profile_cache_entries_that_expire(
+ self, last_checked: int
+ ) -> Dict[str, str]:
"""Get all users who haven't been checked since `last_checked`
"""
@@ -144,18 +153,17 @@ class ProfileStore(ProfileWorkerStore):
txn.execute(sql, (last_checked,))
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
- @defer.inlineCallbacks
- def is_subscribed_remote_profile_for_user(self, user_id):
+ async def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -166,7 +174,7 @@ class ProfileStore(ProfileWorkerStore):
if res:
return True
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
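
`is_subscribed_remote_profile_for_user` decides whether the remote profile cache should keep an entry: we only care while the user still appears in `group_users` or `group_invites`, and `maybe_delete_remote_profile_cache` drops the row otherwise. A tiny standalone sketch of that check, with invented table contents:

```python
# Sketch of the "are we still interested in this remote profile?" check.
group_users = {("+g1:example.org", "@alice:remote.org")}
group_invites = {("+g2:example.org", "@bob:remote.org")}


def is_subscribed_remote_profile_for_user(user_id):
    if any(uid == user_id for _, uid in group_users):
        return True
    return any(uid == user_id for _, uid in group_invites)


def maybe_delete_remote_profile_cache(cache, user_id):
    # Mirrors the store: drop the cached profile once nothing references it.
    if not is_subscribed_remote_profile_for_user(user_id):
        cache.pop(user_id, None)


cache = {"@carol:remote.org": {"displayname": "Carol"}}
maybe_delete_remote_profile_cache(cache, "@carol:remote.org")
print(cache)  # {} -- no group references Carol, so the entry is removed
```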
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index a93e1ef198..ea833829ae 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -14,36 +14,35 @@
# limitations under the License.
import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
- def purge_history(self, room_id, token, delete_local_events):
+ async def purge_history(
+ self, room_id: str, token: str, delete_local_events: bool
+ ) -> Set[int]:
"""Deletes room history before a certain point
Args:
- room_id (str):
-
- token (str): A topological token to delete events before
-
- delete_local_events (bool):
+ room_id:
+ token: A topological token to delete events before
+ delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
Returns:
- Deferred[set[int]]: The set of state groups that are referenced by
- deleted events.
+ The set of state groups that are referenced by deleted events.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -62,6 +61,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# event_json
# event_push_actions
# event_reference_hashes
+ # event_relations
# event_search
# event_to_state_groups
# events
@@ -209,6 +209,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
+ "event_relations",
"event_search",
"rejections",
):
@@ -281,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
return referenced_state_groups
- def purge_room(self, room_id):
+ async def purge_room(self, room_id: str) -> List[int]:
"""Deletes all record of a room
Args:
- room_id (str)
+ room_id
Returns:
- Deferred[List[int]]: The list of state groups to delete.
+ The list of state groups to delete.
"""
-
- return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+ return await self.db_pool.runInteraction(
+ "purge_room", self._purge_room_txn, room_id
+ )
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
@@ -361,7 +363,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
- "local_invites",
"room_account_data",
"room_tags",
"local_current_membership",
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index ef8f40959f..0de802a86b 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -16,40 +16,37 @@
import abc
import logging
-from typing import Union
-
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import List, Tuple, Union
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.pusher import PusherWorkerStore
-from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.database import Database
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.util import json_encoder
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-def _load_rules(rawrules, enabled_map):
+def _load_rules(rawrules, enabled_map, use_new_defaults=False):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
- rule["conditions"] = json.loads(rawrule["conditions"])
- rule["actions"] = json.loads(rawrule["actions"])
+ rule["conditions"] = db_to_json(rawrule["conditions"])
+ rule["actions"] = db_to_json(rawrule["actions"])
rule["default"] = False
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
- rules = list(list_with_base_rules(ruleslist))
+ rules = list(list_with_base_rules(ruleslist, use_new_defaults))
for i, rule in enumerate(rules):
rule_id = rule["rule_id"]
@@ -79,19 +76,19 @@ class PushRulesWorkerStore(
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ self._push_rules_stream_id_gen = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[StreamIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
)
- push_rules_prefill, push_rules_id = self.db.get_cache_dict(
+ push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
db_conn,
"push_rules_stream",
entity_column="user_id",
@@ -105,6 +102,8 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
+ self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
@@ -114,9 +113,9 @@ class PushRulesWorkerStore(
"""
raise NotImplementedError()
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_for_user(self, user_id):
- rows = yield self.db.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_for_user(self, user_id):
+ rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -132,15 +131,15 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
- enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+ enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- rules = _load_rules(rows, enabled_map)
+ use_new_defaults = user_id in self._users_new_default_push_rules
- return rules
+ return _load_rules(rows, enabled_map, use_new_defaults)
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_enabled_for_user(self, user_id):
- results = yield self.db.simple_select_list(
+ @cached(max_entries=5000)
+ async def get_push_rules_enabled_for_user(self, user_id):
+ results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -148,9 +147,11 @@ class PushRulesWorkerStore(
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
- def have_push_rules_changed_for_user(self, user_id, last_id):
+ async def have_push_rules_changed_for_user(
+ self, user_id: str, last_id: int
+ ) -> bool:
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
+ return False
else:
def have_push_rules_changed_txn(txn):
@@ -162,23 +163,20 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
- def bulk_get_push_rules(self, user_ids):
+ async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -191,21 +189,26 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
- enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+ enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
- results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+ use_new_defaults = user_id in self._users_new_default_push_rules
+
+ results[user_id] = _load_rules(
+ rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
+ )
return results
- @defer.inlineCallbacks
- def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ async def copy_push_rule_from_room_to_room(
+ self, new_room_id: str, user_id: str, rule: dict
+ ) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
- new_room_id (str): ID of the new room.
- user_id (str): ID of user the push rule belongs to.
- rule (Dict): A push rule.
+ new_room_id: ID of the new room.
+            user_id: ID of user the push rule belongs to.
+ rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -217,7 +220,7 @@ class PushRulesWorkerStore(
condition["pattern"] = new_room_id
# Add the rule for the new room
- yield self.add_push_rule(
+ await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
@@ -225,20 +228,19 @@ class PushRulesWorkerStore(
actions=rule["actions"],
)
- @defer.inlineCallbacks
- def copy_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id
- ):
+ async def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id: str, new_room_id: str, user_id: str
+ ) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
- old_room_id (str): ID of the old room.
- new_room_id (str): ID of the new room.
- user_id (str): ID of user to copy push rules for.
+ old_room_id: ID of the old room.
+ new_room_id: ID of the new room.
+ user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
- user_push_rules = yield self.get_push_rules_for_user(user_id)
+ user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
@@ -247,96 +249,20 @@ class PushRulesWorkerStore(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
- yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
- @defer.inlineCallbacks
- def bulk_get_push_rules_for_room(self, event, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = yield context.get_current_state_ids()
- result = yield self._bulk_get_push_rules_for_room(
- event.room_id, state_group, current_state_ids, event=event
- )
- return result
-
- @cachedInlineCallbacks(num_args=2, cache_context=True)
- def _bulk_get_push_rules_for_room(
- self, room_id, state_group, current_state_ids, cache_context, event=None
- ):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- # We also will want to generate notifs for other people in the room so
- # their unread countss are correct in the event stream, but to avoid
- # generating them for bot / AS users etc, we only do so for people who've
- # sent a read receipt into the room.
-
- users_in_room = yield self._get_joined_users_from_context(
- room_id,
- state_group,
- current_state_ids,
- on_invalidate=cache_context.invalidate,
- event=event,
- )
-
- # We ignore app service users for now. This is so that we don't fill
- # up the `get_if_users_have_pushers` cache with AS entries that we
- # know don't have pushers, nor even read receipts.
- local_users_in_room = {
- u
- for u in users_in_room
- if self.hs.is_mine_id(u)
- and not self.get_if_app_services_interested_in_user(u)
- }
-
- # users in the room who have pushers need to get push rules run because
- # that's how their pushers work
- if_users_with_pushers = yield self.get_if_users_have_pushers(
- local_users_in_room, on_invalidate=cache_context.invalidate
- )
- user_ids = {
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- }
-
- users_with_receipts = yield self.get_users_with_read_receipts_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
-
- # any users with pushers must be ours: they have pushers
- for uid in users_with_receipts:
- if uid in local_users_in_room:
- user_ids.add(uid)
-
- rules_by_user = yield self.bulk_get_push_rules(
- user_ids, on_invalidate=cache_context.invalidate
- )
-
- rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
-
- return rules_by_user
+ await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
- inlineCallbacks=True,
)
- def bulk_get_push_rules_enabled(self, user_ids):
+ async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -348,30 +274,59 @@ class PushRulesWorkerStore(
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
return results
- def get_all_push_rule_updates(self, last_id, current_id, limit):
- """Get all the push rules changes that have happend on the server"""
+ async def get_all_push_rule_updates(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for push_rules replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_push_rule_updates_txn(txn):
- sql = (
- "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
- " op, priority_class, priority, conditions, actions"
- " FROM push_rules_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
+ sql = """
+ SELECT stream_id, user_id
+ FROM push_rules_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
+ updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
- return self.db.runInteraction(
+ limited = False
+ upper_bound = current_id
+ if len(updates) == limit:
+ limited = True
+ upper_bound = updates[-1][0]
+
+ return updates, upper_bound, limited
+
+ return await self.db_pool.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
class PushRuleStore(PushRulesWorkerStore):
- @defer.inlineCallbacks
- def add_push_rule(
+ async def add_push_rule(
self,
user_id,
rule_id,
@@ -380,13 +335,14 @@ class PushRuleStore(PushRulesWorkerStore):
actions,
before=None,
after=None,
- ):
- conditions_json = json.dumps(conditions)
- actions_json = json.dumps(actions)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
+ ) -> None:
+ conditions_json = json_encoder.encode(conditions)
+ actions_json = json_encoder.encode(actions)
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
if before or after:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -400,7 +356,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -431,7 +387,7 @@ class PushRuleStore(PushRulesWorkerStore):
relative_to_rule = before or after
- res = self.db.simple_select_one_txn(
+ res = self.db_pool.simple_select_one_txn(
txn,
table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@@ -554,7 +510,7 @@ class PushRuleStore(PushRulesWorkerStore):
# We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next()
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="push_rules",
values={
@@ -584,20 +540,19 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
- @defer.inlineCallbacks
- def delete_push_rule(self, user_id, rule_id):
+ async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
- user_id (str): The matrix ID of the push rule owner
- rule_id (str): The rule_id of the rule to be deleted
+ user_id: The matrix ID of the push rule owner
+ rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
- self.db.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -605,20 +560,21 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db.runInteraction(
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
- @defer.inlineCallbacks
- def set_push_rule_enabled(self, user_id, rule_id, enabled):
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db.runInteraction(
+ async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -632,7 +588,7 @@ class PushRuleStore(PushRulesWorkerStore):
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next()
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
"push_rules_enable",
{"user_name": user_id, "rule_id": rule_id},
@@ -649,9 +605,10 @@ class PushRuleStore(PushRulesWorkerStore):
op="ENABLE" if enabled else "DISABLE",
)
- @defer.inlineCallbacks
- def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
- actions_json = json.dumps(actions)
+ async def set_push_rule_actions(
+ self, user_id, rule_id, actions, is_default_rule
+ ) -> None:
+ actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
if is_default_rule:
@@ -672,7 +629,7 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"push_rules",
{"user_name": user_id, "rule_id": rule_id},
@@ -689,9 +646,10 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.db.runInteraction(
+ with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
+
+ await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
@@ -711,7 +669,7 @@ class PushRuleStore(PushRulesWorkerStore):
if data is not None:
values.update(data)
- self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
+ self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
@@ -719,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_push_rules_stream_token(self):
- """Get the position of the push rules stream.
- Returns a pair of a stream id for the push_rules stream and the
- room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_current_token()
-
def get_max_push_rules_stream_id(self):
- return self.get_push_rules_stream_token()[0]
+ return self._push_rules_stream_id_gen.get_current_token()
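
The new `get_all_push_rule_updates` (and the equivalent methods added for pushers and receipts further down) all follow the same replication-stream contract: return `(updates, upper_bound, limited)`, and when the LIMIT is hit only advance the token to the last row actually read so the next call resumes from there. A small in-memory sketch of that contract, where the `rows` list stands in for the `push_rules_stream` table (illustrative code, not Synapse's):

from typing import List, Tuple

def get_updates(
    rows: List[Tuple[int, str]], last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    if last_id == current_id:
        return [], current_id, False

    # Equivalent of: WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ?
    selected = sorted(
        (stream_id, (user_id,))
        for stream_id, user_id in rows
        if last_id < stream_id <= current_id
    )[:limit]

    limited = False
    upper_bound = current_id
    if len(selected) == limit:
        # We may have cut rows off: only advance as far as we actually read.
        limited = True
        upper_bound = selected[-1][0]

    return selected, upper_bound, limited

if __name__ == "__main__":
    rows = [(1, "@a:hs"), (2, "@b:hs"), (3, "@c:hs")]
    updates, token, limited = get_updates(rows, last_id=0, current_id=3, limit=2)
    print(updates, token, limited)  # [(1, ('@a:hs',)), (2, ('@b:hs',))] 2 True
    updates, token, limited = get_updates(rows, last_id=token, current_id=3, limit=2)
    print(updates, token, limited)  # [(3, ('@c:hs',))] 3 False
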
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 547b9d69cb..c388468273 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,14 +15,12 @@
# limitations under the License.
import logging
-from typing import Iterable, Iterator
+from typing import Iterable, Iterator, List, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
-from synapse.storage._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.util.caches.descriptors import cached, cachedList
logger = logging.getLogger(__name__)
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
Drops any rows whose data cannot be decoded
"""
for r in rows:
- dataJson = r["data"]
+ data_json = r["data"]
try:
- r["data"] = json.loads(dataJson)
+ r["data"] = db_to_json(data_json)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
r["id"],
- dataJson,
+ data_json,
e.args[0],
)
continue
yield r
- @defer.inlineCallbacks
- def user_has_pusher(self, user_id):
- ret = yield self.db.simple_select_one_onecol(
+ async def user_has_pusher(self, user_id):
+ ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({"user_name": user_id})
- @defer.inlineCallbacks
- def get_pushers_by(self, keyvalues):
- ret = yield self.db.simple_select_list(
+ async def get_pushers_by(self, keyvalues):
+ ret = await self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -87,104 +83,91 @@ class PusherWorkerStore(SQLBaseStore):
)
return self._decode_pushers_rows(ret)
- @defer.inlineCallbacks
- def get_all_pushers(self):
+ async def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
- return rows
+ return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
- def get_all_updated_pushers(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed(([], []))
-
- def get_all_updated_pushers_txn(txn):
- sql = (
- "SELECT id, user_name, access_token, profile_tag, kind,"
- " app_id, app_display_name, device_display_name, pushkey, ts,"
- " lang, data"
- " FROM pushers"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- updated = txn.fetchall()
+ async def get_all_updated_pushers_rows(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for pushers replication stream.
- sql = (
- "SELECT stream_id, user_id, app_id, pushkey"
- " FROM deleted_pushers"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- deleted = txn.fetchall()
-
- return updated, deleted
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
- return self.db.runInteraction(
- "get_all_updated_pushers", get_all_updated_pushers_txn
- )
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
- def get_all_updated_pushers_rows(self, last_id, current_id, limit):
- """Get all the pushers that have changed between the given tokens.
+ The token returned can be used in a subsequent call to this
+            function to get further updates.
- Returns:
- Deferred(list(tuple)): each tuple consists of:
- stream_id (str)
- user_id (str)
- app_id (str)
- pushkey (str)
- was_deleted (bool): whether the pusher was added/updated (False)
- or deleted (True)
+ The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
- return defer.succeed([])
+ return [], current_id, False
def get_all_updated_pushers_rows_txn(txn):
- sql = (
- "SELECT id, user_name, app_id, pushkey"
- " FROM pushers"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC LIMIT ?"
- )
+ sql = """
+ SELECT id, user_name, app_id, pushkey
+ FROM pushers
+ WHERE ? < id AND id <= ?
+ ORDER BY id ASC LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
- results = [list(row) + [False] for row in txn]
-
- sql = (
- "SELECT stream_id, user_id, app_id, pushkey"
- " FROM deleted_pushers"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
+ updates = [
+ (stream_id, (user_name, app_id, pushkey, False))
+ for stream_id, user_name, app_id, pushkey in txn
+ ]
+
+ sql = """
+ SELECT stream_id, user_id, app_id, pushkey
+ FROM deleted_pushers
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
txn.execute(sql, (last_id, current_id, limit))
+ updates.extend(
+ (stream_id, (user_name, app_id, pushkey, True))
+ for stream_id, user_name, app_id, pushkey in txn
+ )
+
+ updates.sort() # Sort so that they're ordered by stream id
- results.extend(list(row) + [True] for row in txn)
- results.sort() # Sort so that they're ordered by stream id
+ limited = False
+ upper_bound = current_id
+ if len(updates) >= limit:
+ limited = True
+ upper_bound = updates[-1][0]
- return results
+ return updates, upper_bound, limited
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
- @cachedInlineCallbacks(num_args=1, max_entries=15000)
- def get_if_user_has_pusher(self, user_id):
+ @cached(num_args=1, max_entries=15000)
+ async def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(
- cached_method_name="get_if_user_has_pusher",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
+ cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
- def get_if_users_have_pushers(self, user_ids):
- rows = yield self.db.simple_select_many_batch(
+ async def get_if_users_have_pushers(self, user_ids):
+ rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -197,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
return result
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
+ async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self.db.simple_update_one(
+ ) -> None:
+ await self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
+ async def update_pusher_last_stream_ordering_and_success(
+ self,
+ app_id: str,
+ pushkey: str,
+ user_id: str,
+ last_stream_ordering: int,
+ last_success: int,
+ ) -> bool:
"""Update the last stream ordering position we've processed up to for
the given pusher.
Args:
- app_id (str)
- pushkey (str)
- last_stream_ordering (int)
- last_success (int)
+ app_id
+ pushkey
+ user_id
+ last_stream_ordering
+ last_success
Returns:
- Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self.db.simple_update(
+ updated = await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -236,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated)
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.db.simple_update(
+ async def update_pusher_failing_since(
+ self, app_id, pushkey, user_id, failing_since
+ ) -> None:
+ await self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
desc="update_pusher_failing_since",
)
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self.db.simple_select_list(
+ async def get_throttle_params_by_room(self, pusher_id):
+ res = await self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -263,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
+ async def set_throttle_params(self, pusher_id, room_id, params) -> None:
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
- yield self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -280,8 +266,7 @@ class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_pusher(
+ async def add_pusher(
self,
user_id,
access_token,
@@ -295,11 +280,11 @@ class PusherStore(PusherWorkerStore):
data,
last_stream_ordering,
profile_tag="",
- ):
- with self._pushers_id_gen.get_next() as stream_id:
+ ) -> None:
+ with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
- yield self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -324,21 +309,22 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
(user_id,),
)
- @defer.inlineCallbacks
- def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
+ async def delete_pusher_by_app_id_pushkey_user_id(
+ self, app_id, pushkey, user_id
+ ) -> None:
def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream(
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.db.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -347,7 +333,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@@ -358,5 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with self._pushers_id_gen.get_next() as stream_id:
- yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
+ with await self._pushers_id_gen.get_next() as stream_id:
+ await self.db_pool.runInteraction(
+ "delete_pusher", delete_pusher_txn, stream_id
+ )
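
`add_pusher` and `delete_pusher_by_app_id_pushkey_user_id` now allocate their stream position with `with await self._pushers_id_gen.get_next() as stream_id:` — `get_next()` is awaited to obtain a context manager, and the new position only becomes visible once the block completes. The toy generator below sketches that shape; it is an assumption about the interface, not Synapse's `StreamIdGenerator`.

import asyncio
from contextlib import contextmanager

class ToyStreamIdGenerator:
    def __init__(self, start: int = 0):
        self._current = start
        self._next = start

    async def get_next(self):
        # Allocating may need to coordinate with other writers, hence async.
        self._next += 1
        allocated = self._next

        @contextmanager
        def manager():
            try:
                yield allocated
            finally:
                # Only advance the published position once the caller has persisted it.
                self._current = max(self._current, allocated)

        return manager()

    def get_current_token(self) -> int:
        return self._current

async def add_pusher(id_gen: ToyStreamIdGenerator, pushers: list, pusher: dict) -> None:
    with await id_gen.get_next() as stream_id:
        pushers.append({**pusher, "id": stream_id})

if __name__ == "__main__":
    id_gen = ToyStreamIdGenerator()
    pushers = []
    asyncio.run(add_pusher(id_gen, pushers, {"app_id": "m.http", "pushkey": "abc"}))
    print(pushers, id_gen.get_current_token())  # [{'app_id': 'm.http', 'pushkey': 'abc', 'id': 1}] 1
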
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/databases/main/receipts.py
index cebdcd409f..4a0d5a320e 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,15 +16,16 @@
import abc
import logging
-
-from canonicaljson import json
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util import json_encoder
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -39,7 +40,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
@@ -55,14 +56,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
raise NotImplementedError()
- @cachedInlineCallbacks()
- def get_users_with_read_receipts_in_room(self, room_id):
- receipts = yield self.get_receipts_for_room(room_id, "m.read")
+ @cached()
+ async def get_users_with_read_receipts_in_room(self, room_id):
+ receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts}
@cached(num_args=2)
- def get_receipts_for_room(self, room_id, receipt_type):
- return self.db.simple_select_list(
+ async def get_receipts_for_room(
+ self, room_id: str, receipt_type: str
+ ) -> List[Dict[str, Any]]:
+ return await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -70,8 +73,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
- def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self.db.simple_select_one_onecol(
+ async def get_last_receipt_event_id_for_user(
+ self, user_id: str, room_id: str, receipt_type: str
+ ) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -83,9 +88,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True,
)
- @cachedInlineCallbacks(num_args=2)
- def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.db.simple_select_list(
+ @cached(num_args=2)
+ async def get_receipts_for_user(self, user_id, receipt_type):
+ rows = await self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -94,8 +99,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows}
- @defer.inlineCallbacks
- def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+ async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT rl.room_id, rl.event_id,"
@@ -109,7 +113,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
+ rows = await self.db_pool.runInteraction(
+ "get_receipts_for_user_with_orderings", f
+ )
return {
row[0]: {
"event_id": row[1],
@@ -119,56 +125,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def get_linearized_receipts_for_rooms(
+ self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_ids (list): List of room_ids.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_ids: List of room_ids.
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- list: A list of receipts.
+ A list of receipts.
"""
room_ids = set(room_ids)
if from_key is not None:
# Only ask the database about rooms where there have been new
# receipts added since `from_key`
- room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key
)
- results = yield self._get_linearized_receipts_for_rooms(
+ results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
return [ev for res in results.values() for ev in res]
- def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ async def get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""Get receipts for a single room for sending to clients.
Args:
- room_ids (str): The room id.
- to_key (int): Max stream id to fetch receipts upto.
- from_key (int): Min stream id to fetch receipts from. None fetches
+            room_id: The room id.
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
- Deferred[list]: A list of receipts.
+ A list of receipts.
"""
if from_key is not None:
# Check the cache first to see if any new receipts have been added
# since`from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
- defer.succeed([])
+ return []
- return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+ return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
- @cachedInlineCallbacks(num_args=3, tree=True)
- def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+ @cached(num_args=3, tree=True)
+ async def _get_linearized_receipts_for_room(
+ self, room_id: str, to_key: int, from_key: Optional[int] = None
+ ) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
@@ -188,11 +199,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
return rows
- rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
+ rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -201,7 +212,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
- ] = json.loads(row["data"])
+ ] = db_to_json(row["data"])
return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@@ -209,9 +220,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
- inlineCallbacks=True,
)
- def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+ async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -238,9 +248,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args))
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- txn_results = yield self.db.runInteraction(
+ txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -258,7 +268,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
- receipt_type[row["user_id"]] = json.loads(row["data"])
+ receipt_type[row["user_id"]] = db_to_json(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -266,31 +276,86 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
+ async def get_users_sent_receipts_between(
+ self, last_id: int, current_id: int
+ ) -> List[str]:
+ """Get all users who sent receipts between `last_id` exclusive and
+ `current_id` inclusive.
+
+ Returns:
+ The list of users.
+ """
+
if last_id == current_id:
-            return defer.succeed([])
+            return []
+ def _get_users_sent_receipts_between_txn(txn):
+ sql = """
+ SELECT DISTINCT user_id FROM receipts_linearized
+ WHERE ? < stream_id AND stream_id <= ?
+ """
+ txn.execute(sql, (last_id, current_id))
+
+ return [r[0] for r in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
+ )
+
+ async def get_all_updated_receipts(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ """Get updates for receipts replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
+ """
+
+ if last_id == current_id:
+ return [], current_id, False
+
def get_all_updated_receipts_txn(txn):
- sql = (
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
- " FROM receipts_linearized"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- )
- args = [last_id, current_id]
- if limit is not None:
- sql += " LIMIT ?"
- args.append(limit)
- txn.execute(sql, args)
+ sql = """
+ SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
+
+ updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
+
+ limited = False
+ upper_bound = current_id
+
+ if len(updates) == limit:
+ limited = True
+ upper_bound = updates[-1][0]
- return [r[0:5] + (json.loads(r[5]),) for r in txn]
+ return updates, upper_bound, limited
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
def _invalidate_get_users_with_receipts_in_room(
- self, room_id, receipt_type, user_id
+ self, room_id: str, receipt_type: str, user_id: str
):
if receipt_type != "m.read":
return
@@ -300,10 +365,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_id, None, update_metrics=False
)
- # first handle the Deferred case
- if isinstance(res, defer.Deferred):
- if res.called:
- res = res.result
+ # first handle the ObservableDeferred case
+ if isinstance(res, ObservableDeferred):
+ if res.has_called():
+ res = res.get_result()
else:
res = None
@@ -316,7 +381,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
@@ -338,7 +403,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
- res = self.db.simple_select_one_txn(
+ res = self.db_pool.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@@ -391,7 +456,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -402,7 +467,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
values={
"stream_id": stream_id,
"event_id": event_id,
- "data": json.dumps(data),
+ "data": json_encoder.encode(data),
},
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
@@ -416,15 +481,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts
- @defer.inlineCallbacks
- def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+ async def insert_receipt(
+ self,
+ room_id: str,
+ receipt_type: str,
+ user_id: str,
+ event_ids: List[str],
+ data: dict,
+ ) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
- return
+ return None
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
@@ -451,13 +522,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.db.runInteraction(
+ linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = self._receipts_id_gen.get_next()
- with stream_id_manager as stream_id:
- event_ts = yield self.db.runInteraction(
+ with await self._receipts_id_gen.get_next() as stream_id:
+ event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -479,14 +549,16 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts,
)
- yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+ await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
return stream_id, max_persisted_id
- def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.db.runInteraction(
+ async def insert_graph_receipt(
+ self, room_id, receipt_type, user_id, event_ids, data
+ ):
+ return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -512,7 +584,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -521,14 +593,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="receipts_graph",
values={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
- "event_ids": json.dumps(event_ids),
- "data": json.dumps(data),
+ "event_ids": json_encoder.encode(event_ids),
+ "data": json_encoder.encode(data),
},
)
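
Several methods in this file pair a per-key `@cached` method with a `@cachedList` batch method that fills the same cache (for example `_get_linearized_receipts_for_room` and `_get_linearized_receipts_for_rooms`). The toy class below illustrates that pairing with a plain dict; it deliberately drops the `to_key`/`from_key` arguments and is not Synapse's cache machinery.

import asyncio
from typing import Dict, Iterable, List

class ReceiptsSketch:
    def __init__(self, db_rows: Dict[str, List[dict]]):
        self._db_rows = db_rows                 # room_id -> receipt rows ("the database")
        self._cache: Dict[str, List[dict]] = {}  # room_id -> cached rows

    async def get_linearized_receipts_for_room(self, room_id: str) -> List[dict]:
        # Per-key path: a cache hit avoids touching the database at all.
        if room_id in self._cache:
            return self._cache[room_id]
        rows = self._db_rows.get(room_id, [])
        self._cache[room_id] = rows
        return rows

    async def get_linearized_receipts_for_rooms(
        self, room_ids: Iterable[str]
    ) -> Dict[str, List[dict]]:
        room_ids = list(room_ids)
        missing = [r for r in room_ids if r not in self._cache]
        # One batched "query" for everything not yet cached; each result is
        # written back into the per-key cache so later single lookups are hits.
        for room_id in missing:
            self._cache[room_id] = self._db_rows.get(room_id, [])
        return {room_id: self._cache[room_id] for room_id in room_ids}

if __name__ == "__main__":
    store = ReceiptsSketch({"!a:hs": [{"user_id": "@u:hs", "event_id": "$e"}]})
    print(asyncio.run(store.get_linearized_receipts_for_rooms(["!a:hs", "!b:hs"])))
    print(asyncio.run(store.get_linearized_receipts_for_room("!a:hs")))  # served from cache
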
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/databases/main/registration.py
index 9768981891..01f20c03c2 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,20 +17,17 @@
import logging
import re
-from typing import Optional
-
-from six import iterkeys
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
@@ -38,15 +35,19 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
+ self._user_id_seq = build_sequence_generator(
+ database.engine, find_max_generated_user_id_localpart, "user_id_seq",
+ )
+
@cached()
- def get_user_by_id(self, user_id):
- return self.db.simple_select_one(
+ async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -65,19 +66,15 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_id",
)
- @defer.inlineCallbacks
- def is_trial_user(self, user_id):
+ async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
Args:
- user_id (str)
-
- Returns:
- Deferred[bool]
+ user_id: The user to check for trial status.
"""
- info = yield self.get_user_by_id(user_id)
+ info = await self.get_user_by_id(user_id)
if not info:
return False
@@ -87,60 +84,61 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial
@cached()
- def get_user_by_access_token(self, token):
+ async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token.
Args:
- token (str): The access token of a user.
+ token: The access token of a user.
Returns:
- defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`,
+ `valid_until_ms`.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
- @cachedInlineCallbacks()
- def get_expiration_ts_for_user(self, user_id):
+ @cached()
+ async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
- user_id (str): The ID of the user.
+ user_id: The ID of the user.
Returns:
- defer.Deferred: None, if the account has no expiration timestamp,
- otherwise int representation of the timestamp (as a number of
- milliseconds since epoch).
+ None, if the account has no expiration timestamp, otherwise int
+ representation of the timestamp (as a number of milliseconds since epoch).
"""
- res = yield self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_ts_for_user",
)
- return res
- @defer.inlineCallbacks
- def set_account_validity_for_user(
- self, user_id, expiration_ts, email_sent, renewal_token=None
- ):
+ async def set_account_validity_for_user(
+ self,
+ user_id: str,
+ expiration_ts: int,
+ email_sent: bool,
+ renewal_token: Optional[str] = None,
+ ) -> None:
"""Updates the account validity properties of the given account, with the
given values.
Args:
- user_id (str): ID of the account to update properties for.
- expiration_ts (int): New expiration date, as a timestamp in milliseconds
+ user_id: ID of the account to update properties for.
+ expiration_ts: New expiration date, as a timestamp in milliseconds
since epoch.
- email_sent (bool): True means a renewal email has been sent for this
- account and there's no need to send another one for the current validity
+ email_sent: True means a renewal email has been sent for this account
+ and there's no need to send another one for the current validity
period.
- renewal_token (str): Renewal token the user can use to extend the validity
+ renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
def set_account_validity_for_user_txn(txn):
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@@ -154,75 +152,69 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
- @defer.inlineCallbacks
- def set_renewal_token_for_user(self, user_id, renewal_token):
+ async def set_renewal_token_for_user(
+ self, user_id: str, renewal_token: str
+ ) -> None:
"""Defines a renewal token for a given user.
Args:
- user_id (str): ID of the user to set the renewal token for.
- renewal_token (str): Random unique string that will be used to renew the
+ user_id: ID of the user to set the renewal token for.
+ renewal_token: Random unique string that will be used to renew the
user's account.
Raises:
StoreError: The provided token is already set for another user.
"""
- yield self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
desc="set_renewal_token_for_user",
)
- @defer.inlineCallbacks
- def get_user_from_renewal_token(self, renewal_token):
+ async def get_user_from_renewal_token(self, renewal_token: str) -> str:
"""Get a user ID from a renewal token.
Args:
- renewal_token (str): The renewal token to perform the lookup with.
+ renewal_token: The renewal token to perform the lookup with.
Returns:
- defer.Deferred[str]: The ID of the user to which the token belongs.
+ The ID of the user to which the token belongs.
"""
- res = yield self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
- return res
-
- @defer.inlineCallbacks
- def get_renewal_token_for_user(self, user_id):
+ async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
Args:
- user_id (str): The user ID to lookup a token for.
+ user_id: The user ID to lookup a token for.
Returns:
- defer.Deferred[str]: The renewal token associated with this user ID.
+ The renewal token associated with this user ID.
"""
- res = yield self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
- return res
-
- @defer.inlineCallbacks
- def get_users_expiring_soon(self):
+ async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+            A list of dictionaries, each with a user ID and an expiration timestamp (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
@@ -232,58 +224,54 @@ class RegistrationWorkerStore(SQLBaseStore):
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- res = yield self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self.config.account_validity.renew_at,
)
- return res
-
- @defer.inlineCallbacks
- def set_renewal_mail_status(self, user_id, email_sent):
+ async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
"""Sets or unsets the flag that indicates whether a renewal email has been sent
to the user (and the user hasn't renewed their account yet).
Args:
- user_id (str): ID of the user to set/unset the flag for.
- email_sent (bool): Flag which indicates whether a renewal email has been sent
+ user_id: ID of the user to set/unset the flag for.
+ email_sent: Flag which indicates whether a renewal email has been sent
to this user.
"""
- yield self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
desc="set_renewal_mail_status",
)
- @defer.inlineCallbacks
- def delete_account_validity_for_user(self, user_id):
+ async def delete_account_validity_for_user(self, user_id: str) -> None:
"""Deletes the entry for the given user in the account validity table, removing
their expiration date and renewal token.
Args:
- user_id (str): ID of the user to remove from the account validity table.
+ user_id: ID of the user to remove from the account validity table.
"""
- yield self.db.simple_delete_one(
+ await self.db_pool.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
)
- async def is_server_admin(self, user):
+ async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
Args:
- user (UserID): user ID of the user to test
+ user: user ID of the user to test
- Returns (bool):
+ Returns:
true iff the user is a server admin, false otherwise.
"""
- res = await self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@@ -293,28 +281,27 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False
- def set_server_admin(self, user, admin):
+ async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver.
Args:
- user (UserID): user ID of the user to test
- admin (bool): true iff the user is to be a server admin,
- false otherwise.
+ user: user ID of the user whose admin status is being set
+ admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user.to_string(),)
)
- return self.db.runInteraction("set_server_admin", set_server_admin_txn)
+ await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -322,43 +309,42 @@ class RegistrationWorkerStore(SQLBaseStore):
)
txn.execute(sql, (token,))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]
return None
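_query_for_auth now also selects users.shadow_banned, so whatever is built from this row can carry the flag. A hedged sketch of how a caller might surface it (the AuthedUser container is illustrative and not part of this diff):

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class AuthedUser:
    user_id: str
    device_id: Optional[str]
    shadow_banned: bool = False


def user_from_auth_row(row: dict) -> AuthedUser:
    # shadow_banned is a nullable column, so coerce NULL/None to False.
    return AuthedUser(
        user_id=row["name"],
        device_id=row["device_id"],
        shadow_banned=bool(row["shadow_banned"]),
    )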
- @cachedInlineCallbacks()
- def is_real_user(self, user_id):
+ @cached()
+ async def is_real_user(self, user_id: str) -> bool:
"""Determines if the user is a real user, ie does not have a 'user_type'.
Args:
- user_id (str): user id to test
+ user_id: user id to test
Returns:
- Deferred[bool]: True if user 'user_type' is null or empty string
+ True if user 'user_type' is null or empty string
"""
- res = yield self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"is_real_user", self.is_real_user_txn, user_id
)
- return res
@cached()
- def is_support_user(self, user_id):
+ async def is_support_user(self, user_id: str) -> bool:
"""Determines if the user is of type UserTypes.SUPPORT
Args:
- user_id (str): user id to test
+ user_id: user id to test
Returns:
- Deferred[bool]: True if user is of type UserTypes.SUPPORT
+ True if user is of type UserTypes.SUPPORT
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
def is_real_user_txn(self, txn, user_id):
- res = self.db.simple_select_one_onecol_txn(
+ res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -368,7 +354,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res is None
def is_support_user_txn(self, txn, user_id):
- res = self.db.simple_select_one_onecol_txn(
+ res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -377,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
- def get_users_by_id_case_insensitive(self, user_id):
+ async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively.
- Returns a mapping of user_id -> password_hash.
+
+ Returns:
+ A mapping of user_id -> password_hash.
"""
def f(txn):
@@ -387,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
- return self.db.runInteraction("get_users_by_id_case_insensitive", f)
+ return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@@ -401,7 +389,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: the mxid of the user, or None if they are not known
"""
- return await self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
@@ -409,21 +397,19 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_external_id",
)
- @defer.inlineCallbacks
- def count_all_users(self):
+ async def count_all_users(self) -> int:
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.db.runInteraction("count_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_users", _count_users)
- def count_daily_user_type(self):
+ async def count_daily_user_type(self) -> Dict[str, int]:
"""
Counts 1) native non guest users
2) native guest users
@@ -452,10 +438,11 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
- return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
+ return await self.db_pool.runInteraction(
+ "count_daily_user_type", _count_daily_user_type
+ )
- @defer.inlineCallbacks
- def count_nonbridged_users(self):
+ async def count_nonbridged_users(self) -> int:
def _count_users(txn):
txn.execute(
"""
@@ -466,56 +453,31 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_users", _count_users)
- @defer.inlineCallbacks
- def count_real_users(self):
+ async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.db.runInteraction("count_real_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_real_users", _count_users)
- @defer.inlineCallbacks
- def find_next_generated_user_id_localpart(self):
- """
- Gets the localpart of the next generated user ID.
+ async def generate_user_id(self) -> str:
+ """Generate a suitable localpart for a guest user
- Generated user IDs are integers, so we find the largest integer user ID
- already taken and return that plus one.
+ Returns: a (hopefully) free localpart
"""
-
- def _find_next_generated_user_id(txn):
- # We bound between '@0' and '@a' to avoid pulling the entire table
- # out.
- txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
-
- regex = re.compile(r"^@(\d+):")
-
- max_found = 0
-
- for (user_id,) in txn:
- match = regex.search(user_id)
- if match:
- max_found = max(int(match.group(1)), max_found)
-
- return max_found + 1
-
- return (
- (
- yield self.db.runInteraction(
- "find_next_generated_user_id", _find_next_generated_user_id
- )
- )
+ next_id = await self.db_pool.runInteraction(
+ "generate_user_id", self._user_id_seq.get_next_id_txn
)
+ return str(next_id)
+
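generate_user_id replaces the old table scan with a draw from a dedicated sequence (self._user_id_seq), so concurrent workers can hand out guest localparts without racing each other. A minimal sketch of what a Postgres-backed sequence generator could look like (the class name and placeholder style are assumptions, not shown in this diff):

class PostgresSequenceGenerator:
    """Draws monotonically increasing IDs from a database sequence."""

    def __init__(self, sequence_name: str):
        self._sequence_name = sequence_name

    def get_next_id_txn(self, txn) -> int:
        # nextval() is atomic, so two workers can never be handed the same
        # value, unlike the old "largest existing ID plus one" scan.
        txn.execute("SELECT nextval(?)", (self._sequence_name,))
        return txn.fetchone()[0]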
async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
@@ -526,7 +488,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
The user ID or None if no user id/threepid mapping exists
"""
- user_id = await self.db.runInteraction(
+ user_id = await self.db_pool.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
return user_id
@@ -542,7 +504,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
- ret = self.db.simple_select_one_txn(
+ ret = self.db_pool.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
@@ -553,61 +515,57 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret["user_id"]
return None
- @defer.inlineCallbacks
- def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self.db.simple_upsert(
+ async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- @defer.inlineCallbacks
- def user_get_threepids(self, user_id):
- ret = yield self.db.simple_select_list(
+ async def user_get_threepids(self, user_id):
+ return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)
- return ret
- def user_delete_threepid(self, user_id, medium, address):
- return self.db.simple_delete(
+ async def user_delete_threepid(self, user_id, medium, address) -> None:
+ await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepid",
)
- def user_delete_threepids(self, user_id: str):
+ async def user_delete_threepids(self, user_id: str) -> None:
"""Delete all threepid this user has bound
Args:
user_id: The user id to delete all threepids of
"""
- return self.db.simple_delete(
+ await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id},
desc="user_delete_threepids",
)
- def add_user_bound_threepid(self, user_id, medium, address, id_server):
+ async def add_user_bound_threepid(
+ self, user_id: str, medium: str, address: str, id_server: str
+ ):
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str)
-
- Returns:
- Deferred
+ user_id: The user to bind the threepid for.
+ medium: The medium of the threepid (e.g. "email").
+ address: The address of the threepid.
+ id_server: The identity server the bind request was proxied to.
"""
# We need to use an upsert, in case the user had already bound the
# threepid
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -620,41 +578,40 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid",
)
- def user_get_bound_threepids(self, user_id):
+ async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
"""Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids.
Args:
- user_id (str): The ID of the user to retrieve threepids for
+ user_id: The ID of the user to retrieve threepids for
Returns:
- Deferred[list[dict]]: List of dictionaries containing the following:
+ List of dictionaries containing the following keys:
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
- return self.db.simple_select_list(
+ return await self.db_pool.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
desc="user_get_bound_threepids",
)
- def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+ async def remove_user_bound_threepid(
+ self, user_id: str, medium: str, address: str, id_server: str
+ ) -> None:
"""The server proxied an unbind request to the given identity server on
behalf of the given user, so we remove the mapping of threepid to
identity server.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str)
-
- Returns:
- Deferred
+ user_id: The user the threepid was bound for.
+ medium: The medium of the threepid.
+ address: The address of the threepid.
+ id_server: The identity server the unbind request was proxied to.
"""
- return self.db.simple_delete(
+ await self.db_pool.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -665,37 +622,39 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="remove_user_bound_threepid",
)
- def get_id_servers_user_bound(self, user_id, medium, address):
+ async def get_id_servers_user_bound(
+ self, user_id: str, medium: str, address: str
+ ) -> List[str]:
"""Get the list of identity servers that the server proxied bind
requests to for given user and threepid
Args:
- user_id (str)
- medium (str)
- address (str)
+ user_id: The user to query for identity servers.
+ medium: The medium to query for identity servers.
+ address: The address to query for identity servers.
Returns:
- Deferred[list[str]]: Resolves to a list of identity servers
+ A list of identity servers
"""
- return self.db.simple_select_onecol(
+ return await self.db_pool.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
desc="get_id_servers_user_bound",
)
- @cachedInlineCallbacks()
- def get_user_deactivated_status(self, user_id):
+ @cached()
+ async def get_user_deactivated_status(self, user_id: str) -> bool:
"""Retrieve the value for the `deactivated` property for the provided user.
Args:
- user_id (str): The ID of the user to retrieve the status for.
+ user_id: The ID of the user to retrieve the status for.
Returns:
- defer.Deferred(bool): The requested value.
+ True if the user was deactivated, false if the user is still active.
"""
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@@ -705,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean.
return res == 1
- def get_threepid_validation_session(
- self, medium, client_secret, address=None, sid=None, validated=True
- ):
+ async def get_threepid_validation_session(
+ self,
+ medium: Optional[str],
+ client_secret: str,
+ address: Optional[str] = None,
+ sid: Optional[str] = None,
+ validated: Optional[bool] = True,
+ ) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
Args:
- medium (str|None): The medium of the 3PID
- address (str|None): The address of the 3PID
- sid (str|None): The ID of the validation session
- client_secret (str): A unique string provided by the client to help identify this
+ medium: The medium of the 3PID
+ client_secret: A unique string provided by the client to help identify this
validation attempt
- validated (bool|None): Whether sessions should be filtered by
+ address: The address of the 3PID
+ sid: The ID of the validation session
+ validated: Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
- Deferred[dict|None]: A dict containing the following:
+ A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
@@ -753,7 +717,7 @@ class RegistrationWorkerStore(SQLBaseStore):
last_send_attempt, validated_at
FROM threepid_validation_session WHERE %s
""" % (
- " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+ " AND ".join("%s = ?" % k for k in keyvalues.keys()),
)
if validated is not None:
@@ -762,57 +726,57 @@ class RegistrationWorkerStore(SQLBaseStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return None
return rows[0]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
- def delete_threepid_session(self, session_id):
+ async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was
waiting on it has been carried out
Args:
- session_id (str): The ID of the session to delete
+ session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"users_creation_ts",
index_name="users_creation_ts",
table="users",
@@ -822,18 +786,19 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
- self.db.updates.register_noop_background_update("refresh_tokens_device_index")
+ self.db_pool.updates.register_noop_background_update(
+ "refresh_tokens_device_index"
+ )
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_threepids_grandfather", self._bg_user_threepids_grandfather
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- @defer.inlineCallbacks
- def _background_update_set_deactivated_flag(self, progress, batch_size):
+ async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
@@ -861,7 +826,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return True, 0
@@ -875,7 +840,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
logger.info("Marked %d rows as deactivated", rows_processed_nb)
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
)
@@ -884,17 +849,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
else:
return False, len(rows)
- end, nb_processed = yield self.db.runInteraction(
+ end, nb_processed = await self.db_pool.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
- yield self.db.updates._end_background_update("users_set_deactivated_flag")
+ await self.db_pool.updates._end_background_update(
+ "users_set_deactivated_flag"
+ )
return nb_processed
- @defer.inlineCallbacks
- def _bg_user_threepids_grandfather(self, progress, batch_size):
+ async def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
@@ -915,20 +881,21 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
- yield self.db.updates._end_background_update("user_threepids_grandfather")
+ await self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1
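Both background updates above follow the same contract: the handler is called with a progress dict and a batch size, persists its own progress inside the transaction, calls _end_background_update when there is nothing left, and returns how many rows it processed. A stripped-down sketch of the driving loop (the real updater tracks completion via _end_background_update; this sketch simplifies that to a zero return, and the names are made up):

async def run_background_update(handler, batch_size: int = 100) -> None:
    # Repeatedly invoke the handler until it reports that no rows were
    # processed in a batch; each call persists its own progress.
    progress: dict = {}
    while True:
        processed = await handler(progress, batch_size)
        if processed == 0:
            break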
class RegistrationStore(RegistrationBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -949,23 +916,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
- @defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
+ async def add_access_token_to_user(
+ self,
+ user_id: str,
+ token: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ ) -> None:
"""Adds an access token for the given user.
Args:
- user_id (str): The user ID.
- token (str): The new access token to add.
- device_id (str): ID of the device to associate with the access
- token
- valid_until_ms (int|None): when the token is valid until. None for
- no expiry.
+ user_id: The user ID.
+ token: The new access token to add.
+ device_id: ID of the device to associate with the access token
+ valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._access_tokens_id_gen.get_next()
- yield self.db.simple_insert(
+ await self.db_pool.simple_insert(
"access_tokens",
{
"id": next_id,
@@ -977,40 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
- def register_user(
+ async def register_user(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Attempts to register an account.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
- upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
- false to add a regular user account.
- appservice_id (str): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode): Optionally create a profile for
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Whether this is a guest account being upgraded to a
+ non-guest account.
+ make_guest: True if the new user should be a guest, false to add a
+ regular user account.
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
- api.constants.UserTypes, or None for a normal user.
+ admin: is an admin user?
+ user_type: type of user. One of the values from api.constants.UserTypes,
+ or None for a normal user.
+ shadow_banned: Whether the user is shadow-banned, i.e. they may be
+ told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
-
- Returns:
- Deferred
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@@ -1021,6 +991,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
)
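register_user gains a shadow_banned keyword that is written straight into the users table alongside the other account attributes. A hedged usage sketch (the surrounding handler is assumed, not shown here):

async def register_shadow_banned_user(store, user_id: str, password_hash: str) -> None:
    # The account looks normal to its owner, but other parts of the server
    # can check the flag and quietly ignore the user's requests.
    await store.register_user(
        user_id=user_id,
        password_hash=password_hash,
        shadow_banned=True,
    )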
def _register_user(
@@ -1034,6 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
create_profile_with_displayname,
admin,
user_type,
+ shadow_banned,
):
user_id_obj = UserID.from_string(user_id)
@@ -1044,7 +1016,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
- self.db.simple_select_one_txn(
+ self.db_pool.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1052,7 +1024,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
allow_none=False,
)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1063,10 +1035,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
else:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"users",
values={
@@ -1077,6 +1050,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
+ "shadow_banned": shadow_banned,
},
)
@@ -1109,11 +1083,10 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- txn.call_after(self.is_guest.invalidate, (user_id,))
- def record_user_external_id(
+ async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
- ) -> Deferred:
+ ) -> None:
"""Record a mapping from an external user id to a mxid
Args:
@@ -1121,7 +1094,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@@ -1131,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- def user_set_password_hash(self, user_id, password_hash):
+ async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1139,29 +1112,30 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def user_set_password_hash_txn(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
- def user_set_consent_version(self, user_id, consent_version):
+ async def user_set_consent_version(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy the user has consented
- to
+ user_id: full mxid of the user to update
+ consent_version: version of the policy the user has consented to
Raises:
StoreError(404) if user not found
"""
def f(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1169,23 +1143,24 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db.runInteraction("user_set_consent_version", f)
+ await self.db_pool.runInteraction("user_set_consent_version", f)
- def user_set_consent_server_notice_sent(self, user_id, consent_version):
+ async def user_set_consent_server_notice_sent(
+ self, user_id: str, consent_version: str
+ ) -> None:
"""Updates the user table to record that we have sent the user a server
notice about privacy policy consent
Args:
- user_id (str): full mxid of the user to update
- consent_version (str): version of the policy we have notified the
- user about
+ user_id: full mxid of the user to update
+ consent_version: version of the policy we have notified the user about
Raises:
StoreError(404) if user not found
"""
def f(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1193,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db.runInteraction("user_set_consent_server_notice_sent", f)
+ await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
- def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
+ async def user_delete_access_tokens(
+ self,
+ user_id: str,
+ except_token_id: Optional[str] = None,
+ device_id: Optional[str] = None,
+ ) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access tokens belonging to a user
Args:
- user_id (str): ID of user the tokens belong to
- except_token_id (str): list of access_tokens IDs which should
- *not* be deleted
- device_id (str|None): ID of device the tokens are associated with.
+ user_id: ID of user the tokens belong to
+ except_token_id: access_tokens ID which should *not* be deleted
+ device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
- defer.Deferred[list[str, int, str|None, int]]: a list of
- (token, token id, device id) for each of the deleted tokens
+ A list of tuples of (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
@@ -1239,11 +1217,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
- return self.db.runInteraction("user_delete_access_tokens", f)
+ return await self.db_pool.runInteraction("user_delete_access_tokens", f)
- def delete_access_token(self, access_token):
+ async def delete_access_token(self, access_token: str) -> None:
def f(txn):
- self.db.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -1251,11 +1229,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
- return self.db.runInteraction("delete_access_token", f)
+ await self.db_pool.runInteraction("delete_access_token", f)
- @cachedInlineCallbacks()
- def is_guest(self, user_id):
- res = yield self.db.simple_select_one_onecol(
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@@ -1265,36 +1243,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return res if res else False
- def add_user_pending_deactivation(self, user_id):
+ async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
)
- def del_user_pending_deactivation(self, user_id):
+ async def del_user_pending_deactivation(self, user_id: str) -> None:
"""
Removes the given user from the table of users who need to be parted from all the
rooms they're in, effectively marking that user as fully deactivated.
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
- return self.db.simple_delete(
+ await self.db_pool.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
)
- def get_user_pending_deactivation(self):
+ async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1302,29 +1280,30 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation",
)
- def validate_threepid_session(self, session_id, client_secret, token, current_ts):
+ async def validate_threepid_session(
+ self, session_id: str, client_secret: str, token: str, current_ts: int
+ ) -> Optional[str]:
"""Attempt to validate a threepid session using a token
Args:
- session_id (str): The id of a validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- token (str): A validation token
- current_ts (int): The current unix time in milliseconds. Used for
- checking token expiry status
+ session_id: The id of a validation session
+ client_secret: A unique string provided by the client to help identify
+ this validation attempt
+ token: A validation token
+ current_ts: The current unix time in milliseconds. Used for checking
+ token expiry status
Raises:
ThreepidValidationError: if a matching validation token was not found or has
expired
Returns:
- deferred str|None: A str representing a link to redirect the user
- to if there is one.
+ A str representing a link to redirect the user to if there is one.
"""
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
- row = self.db.simple_select_one_txn(
+ row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1333,16 +1312,23 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
- raise ThreepidValidationError(400, "Unknown session_id")
+ if self._ignore_unknown_session_error:
+ # If we need to inhibit the error caused by an incorrect session ID,
+ # use None as placeholder values for the client secret and the
+ # validation timestamp.
+ # It shouldn't be an issue because they're both only checked after
+ # the token check, which should fail. And if it doesn't for some
+ # reason, the next check is on the client secret, which is NOT NULL,
+ # so we don't have to worry about the client secret matching by
+ # accident.
+ row = {"client_secret": None, "validated_at": None}
+ else:
+ raise ThreepidValidationError(400, "Unknown session_id")
+
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
- if retrieved_client_secret != client_secret:
- raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id"
- )
-
- row = self.db.simple_select_one_txn(
+ row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@@ -1357,6 +1343,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id"
+ )
+
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
@@ -1367,7 +1358,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Looks good. Validate the session
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1377,78 +1368,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
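Two behavioural changes are folded into validate_threepid_session here: the client_secret comparison now happens after the token lookup, and an unknown session_id can be masked (when hs.config.request_token_inhibit_3pid_errors is set) so callers cannot probe which sessions exist. A hedged sketch of the calling side (the fallback URL is illustrative):

async def handle_submit_token(store, session_id: str, client_secret: str,
                              token: str, now_ms: int) -> str:
    # Returns the next_link to redirect the user to, or raises
    # ThreepidValidationError. With request_token_inhibit_3pid_errors
    # enabled, an unknown session_id surfaces as a failed token check
    # rather than "Unknown session_id".
    next_link = await store.validate_threepid_session(
        session_id, client_secret, token, now_ms
    )
    return next_link or "/_matrix/static/client/register/success.html"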
- def upsert_threepid_validation_session(
+ async def start_or_continue_validation_session(
self,
- medium,
- address,
- client_secret,
- send_attempt,
- session_id,
- validated_at=None,
- ):
- """Upsert a threepid validation session
- Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- session_id (str): The id of this validation session
- validated_at (int|None): The unix timestamp in milliseconds of
- when the session was marked as valid
- """
- insertion_values = {
- "medium": medium,
- "address": address,
- "client_secret": client_secret,
- }
-
- if validated_at:
- insertion_values["validated_at"] = validated_at
-
- return self.db.simple_upsert(
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- values={"last_send_attempt": send_attempt},
- insertion_values=insertion_values,
- desc="upsert_threepid_validation_session",
- )
-
- def start_or_continue_validation_session(
- self,
- medium,
- address,
- session_id,
- client_secret,
- send_attempt,
- next_link,
- token,
- token_expires,
- ):
+ medium: str,
+ address: str,
+ session_id: str,
+ client_secret: str,
+ send_attempt: int,
+ next_link: Optional[str],
+ token: str,
+ token_expires: int,
+ ) -> None:
"""Creates a new threepid validation session if it does not already
exist and associates a new validation token with it
Args:
- medium (str): The medium of the 3PID
- address (str): The address of the 3PID
- session_id (str): The id of this validation session
- client_secret (str): A unique string provided by the client to
- help identify this validation attempt
- send_attempt (int): The latest send_attempt on this session
- next_link (str|None): The link to redirect the user to upon
- successful validation
- token (str): The validation token
- token_expires (int): The timestamp for which after the token
- will no longer be valid
+ medium: The medium of the 3PID
+ address: The address of the 3PID
+ session_id: The id of this validation session
+ client_secret: A unique string provided by the client to help
+ identify this validation attempt
+ send_attempt: The latest send_attempt on this session
+ next_link: The link to redirect the user to upon successful validation
+ token: The validation token
+ token_expires: The timestamp after which the token will no
+ longer be valid
"""
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1461,7 +1414,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Create a new validation token with this session ID
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@@ -1472,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
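start_or_continue_validation_session absorbs the old upsert_threepid_validation_session helper: a single call upserts the session row and inserts the fresh token for it. A hedged sketch of the calling side (session and token generation shown here is illustrative):

import secrets


async def send_threepid_validation(store, medium: str, address: str,
                                   client_secret: str, send_attempt: int,
                                   now_ms: int) -> str:
    session_id = secrets.token_urlsafe(16)
    token = secrets.token_urlsafe(32)
    await store.start_or_continue_validation_session(
        medium=medium,
        address=address,
        session_id=session_id,
        client_secret=client_secret,
        send_attempt=send_attempt,
        next_link=None,
        token=token,
        # One hour of validity for the emailed token (an illustrative choice).
        token_expires=now_ms + 60 * 60 * 1000,
    )
    return session_id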
- def cull_expired_threepid_validation_tokens(self):
+ async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
@@ -1485,24 +1438,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
- return txn.execute(sql, (ts,))
+ txn.execute(sql, (ts,))
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
)
- @defer.inlineCallbacks
- def set_user_deactivated_status(self, user_id, deactivated):
+ async def set_user_deactivated_status(
+ self, user_id: str, deactivated: bool
+ ) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
- user_id (str): The ID of the user to set the status for.
- deactivated (bool): The value to set for `deactivated`.
+ user_id: The ID of the user to set the status for.
+ deactivated: The value to set for `deactivated`.
"""
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@@ -1510,7 +1464,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -1519,9 +1473,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
+ txn.call_after(self.is_guest.invalidate, (user_id,))
- @defer.inlineCallbacks
- def _set_expiration_date_when_missing(self):
+ async def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
@@ -1538,14 +1492,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.execute(sql, [])
- res = self.db.cursor_to_dict(txn)
+ res = self.db_pool.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)
@@ -1569,9 +1523,32 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expiration_ts,
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
"account_validity",
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
+
+
+def find_max_generated_user_id_localpart(cur: Cursor) -> int:
+ """
+ Gets the localpart of the max current generated user ID.
+
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that.
+ """
+
+ # We bound between '@0' and '@a' to avoid pulling the entire table
+ # out.
+ cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
+
+ regex = re.compile(r"^@(\d+):")
+
+ max_found = 0
+
+ for (user_id,) in cur:
+ match = regex.search(user_id)
+ if match:
+ max_found = max(int(match.group(1)), max_found)
+ return max_found
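The table scan survives only as find_max_generated_user_id_localpart, which exists to seed the sequence that generate_user_id now draws from. A hedged, Postgres-flavoured sketch of how that seeding could be wired into a schema delta (the sequence name and setval call are assumptions):

def create_and_seed_user_id_sequence(cur) -> None:
    # Start the sequence past the highest numeric localpart already taken,
    # so freshly generated guest IDs can never collide with existing users.
    max_existing = find_max_generated_user_id_localpart(cur)
    cur.execute("CREATE SEQUENCE IF NOT EXISTS user_id_seq")
    cur.execute("SELECT setval('user_id_seq', %s)", (max_existing + 1,))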
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/databases/main/rejections.py
index 27e5a2084a..1e361aaa9a 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
- def get_rejection_reason(self, event_id):
- return self.db.simple_select_one_onecol(
+ async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+ return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/databases/main/relations.py
index 7d477f8d01..5cd61547f7 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,56 +14,53 @@
# limitations under the License.
import logging
+from typing import Optional
import attr
from synapse.api.constants import RelationTypes
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.stream import generate_pagination_where_clause
+from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
- def get_relations_for_event(
+ async def get_relations_for_event(
self,
- event_id,
- relation_type=None,
- event_type=None,
- aggregation_key=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[RelationPaginationToken] = None,
+ to_token: Optional[RelationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
- event_id (str): Fetch events that relate to this event ID.
- relation_type (str|None): Only fetch events with this relation
- type, if given.
- event_type (str|None): Only fetch events with this event type, if
- given.
- aggregation_key (str|None): Only fetch events with this aggregation
- key, if given.
- limit (int): Only fetch the most recent `limit` events.
- direction (str): Whether to fetch the most recent first (`"b"`) or
- the oldest first (`"f"`).
- from_token (RelationPaginationToken|None): Fetch rows from the given
- token, or from the start if None.
- to_token (RelationPaginationToken|None): Fetch rows up to the given
- token, or up to the end if None.
+ event_id: Fetch events that relate to this event ID.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of event IDs that match relations
- requested. The rows are of the form `{"event_id": "..."}`.
+ List of event IDs that match relations requested. The rows are of
+ the form `{"event_id": "..."}`.
"""
where_clause = ["relates_to_id = ?"]
@@ -129,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
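get_relations_for_event now returns the PaginationChunk directly rather than a Deferred. A hedged usage sketch for fetching the first page of reactions to an event (RelationTypes.ANNOTATION is the reaction relation type; the attribute access on the chunk follows the PaginationChunk fields and should be treated as an assumption):

from synapse.api.constants import RelationTypes


async def first_page_of_reactions(store, event_id: str) -> list:
    chunk = await store.get_relations_for_event(
        event_id,
        relation_type=RelationTypes.ANNOTATION,
        limit=10,
    )
    # chunk.chunk holds rows of the form {"event_id": ...}; chunk.next_batch
    # is the token to pass back as from_token for the following page.
    return [row["event_id"] for row in chunk.chunk]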
@cached(tree=True)
- def get_aggregation_groups_for_event(
+ async def get_aggregation_groups_for_event(
self,
- event_id,
- event_type=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
+ event_id: str,
+ event_type: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[AggregationPaginationToken] = None,
+ to_token: Optional[AggregationPaginationToken] = None,
+ ) -> PaginationChunk:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -150,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
on an event.
Args:
- event_id (str): Fetch events that relate to this event ID.
- event_type (str|None): Only fetch events with this event type, if
- given.
- limit (int): Only fetch the `limit` groups.
- direction (str): Whether to fetch the highest count first (`"b"`) or
+ event_id: Fetch events that relate to this event ID.
+ event_type: Only fetch events with this event type, if given.
+ limit: Only fetch the `limit` groups.
+ direction: Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
- from_token (AggregationPaginationToken|None): Fetch rows from the
- given token, or from the start if None.
- to_token (AggregationPaginationToken|None): Fetch rows up to the
- given token, or up to the end if None.
-
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- Deferred[PaginationChunk]: List of groups of annotations that
- match. Each row is a dict with `type`, `key` and `count` fields.
+ List of groups of annotations that match. Each row is a dict with
+ `type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -223,22 +216,22 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
- @cachedInlineCallbacks()
- def get_applicable_edit(self, event_id):
+ @cached()
+ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.
Correctly handles checking whether edits were allowed to happen.
Args:
- event_id (str): The original event ID
+ event_id: The original event ID
Returns:
- Deferred[EventBase|None]: Returns the most recent edit, if any.
+ The most recent edit, if any.
"""
# We only allow edits for `m.room.message` events that have the same sender
@@ -268,28 +261,29 @@ class RelationsWorkerStore(SQLBaseStore):
if row:
return row[0]
- edit_id = yield self.db.runInteraction(
+ edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
)
if not edit_id:
- return
+ return None
- edit_event = yield self.get_event(edit_id, allow_none=True)
- return edit_event
+ return await self.get_event(edit_id, allow_none=True)
- def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ async def has_user_annotated_event(
+ self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+ ) -> bool:
"""Check if a user has already annotated an event with the same key
(e.g. already liked an event).
Args:
- parent_id (str): The event being annotated
- event_type (str): The event type of the annotation
- aggregation_key (str): The aggregation key of the annotation
- sender (str): The sender of the annotation
+ parent_id: The event being annotated
+ event_type: The event type of the annotation
+ aggregation_key: The aggregation key of the annotation
+ sender: The sender of the annotation
Returns:
- Deferred[bool]
+ True if the event is already annotated.
"""
sql = """
@@ -318,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/databases/main/room.py
index 46f643c6b9..717df97301 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,26 +21,19 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.search import SearchStore
-from synapse.storage.database import Database, LoggingTransaction
-from synapse.types import ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.search import SearchStore
+from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.util import json_encoder
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
-OpsLevel = collections.namedtuple(
- "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@@ -75,20 +68,20 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config
- def get_room(self, room_id):
+ async def get_room(self, room_id: str) -> Optional[dict]:
"""Retrieve a room.
Args:
- room_id (str): The ID of the room to retrieve.
+ room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self.db.simple_select_one(
+ return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -96,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
- def get_room_with_stats(self, room_id: str):
+ async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve room with statistics.
Args:
@@ -118,30 +111,39 @@ class RoomWorkerStore(SQLBaseStore):
WHERE room_id = ?
"""
txn.execute(sql, [room_id])
- res = self.db.cursor_to_dict(txn)[0]
+ # If the query returned no rows, return None instead of raising an IndexError
+ try:
+ res = self.db_pool.cursor_to_dict(txn)[0]
+ except IndexError:
+ return None
+
res["federatable"] = bool(res["federatable"])
res["public"] = bool(res["public"])
return res
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
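get_room_with_stats now returns None for an unknown room instead of letting the IndexError escape, which lets callers translate a miss into a clean error. A hedged sketch of the pattern this enables (the 404 choice is illustrative of, say, an admin API):

from synapse.api.errors import NotFoundError


async def room_details(store, room_id: str) -> dict:
    room = await store.get_room_with_stats(room_id)
    if room is None:
        # Unknown room: surface a 404 rather than an internal error.
        raise NotFoundError("Room not found")
    return room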
- def get_public_room_ids(self):
- return self.db.simple_select_onecol(
+ async def get_public_room_ids(self) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
table="rooms",
keyvalues={"is_public": True},
retcol="room_id",
desc="get_public_room_ids",
)
- def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ async def count_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ ignore_non_federatable: bool,
+ ) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
Args:
- network_tuple (ThirdPartyInstanceID|None)
- ignore_non_federatable (bool): If true filters out non-federatable rooms
+ network_tuple: The third-party network/appservice instance to restrict the count to, or None for the main room list.
+ ignore_non_federatable: If true, filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
@@ -185,10 +187,11 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
- return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
+ return await self.db_pool.runInteraction(
+ "count_public_rooms", _count_public_rooms_txn
+ )
- @defer.inlineCallbacks
- def get_largest_public_rooms(
+ async def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
@@ -318,21 +321,21 @@ class RoomWorkerStore(SQLBaseStore):
def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args)
- results = self.db.cursor_to_dict(txn)
+ results = self.db_pool.cursor_to_dict(txn)
if not forwards:
results.reverse()
return results
- ret_val = yield self.db.runInteraction(
+ ret_val = await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
- defer.returnValue(ret_val)
+ return ret_val
@cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self.db.simple_select_one_onecol(
+ async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+ return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -500,12 +503,12 @@ class RoomWorkerStore(SQLBaseStore):
room_count = txn.fetchone()
return rooms, room_count[0]
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_paginate", _get_rooms_paginate_txn,
)
- @cachedInlineCallbacks(max_entries=10000)
- def get_ratelimit_for_user(self, user_id):
+ @cached(max_entries=10000)
+ async def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
@@ -517,7 +520,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
- row = yield self.db.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
@@ -533,8 +536,8 @@ class RoomWorkerStore(SQLBaseStore):
else:
return None
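
A hedged sketch of how a caller might consume the override, based on the docstring above (the function and its default parameters are illustrative, not from this patch):

async def effective_ratelimit(store, user_id: str, default_rate, default_burst):
    # Illustrative consumer of get_ratelimit_for_user; the defaults would come
    # from config in a real caller.
    override = await store.get_ratelimit_for_user(user_id)
    if override is None:
        # No per-user override: fall back to the configured defaults.
        return default_rate, default_burst
    if not override.messages_per_second and not override.burst_count:
        # Both values None/0: ratelimiting is disabled for this user entirely.
        return None
    return override.messages_per_second, override.burst_count
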
- @cachedInlineCallbacks()
- def get_retention_policy_for_room(self, room_id):
+ @cached()
+ async def get_retention_policy_for_room(self, room_id):
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -559,21 +562,19 @@ class RoomWorkerStore(SQLBaseStore):
(room_id,),
)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- ret = yield self.db.runInteraction(
+ ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
)
# If we don't know this room ID, ret will be None, in this case return the default
# policy.
if not ret:
- defer.returnValue(
- {
- "min_lifetime": self.config.retention_default_min_lifetime,
- "max_lifetime": self.config.retention_default_max_lifetime,
- }
- )
+ return {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
row = ret[0]
@@ -587,17 +588,16 @@ class RoomWorkerStore(SQLBaseStore):
if row["max_lifetime"] is None:
row["max_lifetime"] = self.config.retention_default_max_lifetime
- defer.returnValue(row)
+ return row
- def get_media_mxcs_in_room(self, room_id):
+ async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
- room_id (str)
+ room_id
Returns:
- The local and remote media as a lists of tuples where the key is
- the hostname and the value is the media ID.
+            The local and remote media as lists of the media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
@@ -613,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
- def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ async def quarantine_media_ids_in_room(
+ self, room_id: str, quarantined_by: str
+ ) -> int:
"""For a room loops through all events with media and quarantines
the associated media
"""
@@ -626,37 +628,11 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- total_media_quarantined = 0
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany(
- """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """,
- ((quarantined_by, media_id) for media_id in local_mxcs),
- )
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_mxcs
- ),
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
)
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
-
- return total_media_quarantined
-
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -691,7 +667,7 @@ class RoomWorkerStore(SQLBaseStore):
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
- event_json = json.loads(content_json)
+ event_json = db_to_json(content_json)
content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
@@ -719,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- def quarantine_media_by_id(
+ async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
- ):
+ ) -> int:
"""quarantines a single local or remote media id
Args:
@@ -740,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
- def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ async def quarantine_media_ids_by_user(
+ self, user_id: str, quarantined_by: str
+ ) -> int:
"""quarantines all local media associated with a single user
Args:
@@ -756,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@@ -805,17 +783,17 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
The total number of media items quarantined
"""
- total_media_quarantined = 0
-
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
- WHERE media_id = ?
+ WHERE media_id = ? AND safe_from_quarantine = ?
""",
- ((quarantined_by, media_id) for media_id in local_mxcs),
+ ((quarantined_by, media_id, False) for media_id in local_mxcs),
)
+ # Note that a rowcount of -1 can be used to indicate no rows were affected.
+ total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
txn.executemany(
"""
@@ -825,13 +803,36 @@ class RoomWorkerStore(SQLBaseStore):
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)
-
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
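
The rowcount handling above reflects the DB-API rule that cursor.rowcount may be -1 when the driver cannot determine how many rows were affected. A standalone sketch of the same guard, using plain sqlite3 with an illustrative table (not the Synapse schema):

import sqlite3

def quarantine_count(conn: sqlite3.Connection, quarantined_by: str, media_ids):
    # Illustrative schema; demonstrates clamping rowcount after executemany.
    cur = conn.cursor()
    cur.executemany(
        "UPDATE media SET quarantined_by = ? WHERE media_id = ?",
        ((quarantined_by, media_id) for media_id in media_ids),
    )
    # rowcount may be -1 if the driver could not determine it, so clamp to 0.
    return cur.rowcount if cur.rowcount > 0 else 0
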
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ async def get_all_new_public_rooms(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for public rooms replication stream.
+
+ Args:
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
+ Returns:
+ A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
+ """
+ if last_id == current_id:
+ return [], current_id, False
+
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
@@ -841,13 +842,17 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
+ txn.execute(sql, (last_id, current_id, limit))
+ updates = [(row[0], row[1:]) for row in txn]
+ limited = False
+ upto_token = current_id
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- if prev_id == current_id:
- return defer.succeed([])
+ return updates, upto_token, limited
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
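
A consumer of the (updates, upto_token, limited) contract described in the docstring would typically page until limited is False; a hedged sketch (the loop and names are illustrative):

async def fetch_public_room_updates(store, instance_name, from_token, current_token, batch_size=100):
    # Illustrative pager over get_all_new_public_rooms.
    rows = []
    token, limited = from_token, True
    while limited:
        updates, token, limited = await store.get_all_new_public_rooms(
            instance_name, token, current_token, batch_size
        )
        # Each update is (stream_id, (room_id, visibility, appservice_id, network_id)).
        rows.extend(updates)
    return rows, token
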
@@ -856,27 +861,26 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.config = hs.config
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"insert_room_retention", self._background_insert_retention,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
self._remove_tombstoned_rooms_from_directory,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.ADD_ROOMS_ROOM_VERSION_COLUMN,
self._background_add_rooms_room_version_column,
)
- @defer.inlineCallbacks
- def _background_insert_retention(self, progress, batch_size):
+ async def _background_insert_retention(self, progress, batch_size):
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's
@@ -900,7 +904,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return True
@@ -909,10 +913,10 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
if not row["json"]:
retention_policy = {}
else:
- ev = json.loads(row["json"])
- retention_policy = json.dumps(ev["content"])
+ ev = db_to_json(row["json"])
+ retention_policy = ev["content"]
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@@ -925,7 +929,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
)
@@ -934,14 +938,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
else:
return False
- end = yield self.db.runInteraction(
+ end = await self.db_pool.runInteraction(
"insert_room_retention", _background_insert_retention_txn,
)
if end:
- yield self.db.updates._end_background_update("insert_room_retention")
+ await self.db_pool.updates._end_background_update("insert_room_retention")
- defer.returnValue(batch_size)
+ return batch_size
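
The handler above follows the background-update contract used throughout this patch: process one batch inside runInteraction, record progress, return how many items were handled, and call _end_background_update once nothing is left. A hedged sketch of that shape (the update name and table are illustrative; the db_pool.updates helpers are the ones used above):

async def _background_example_update(self, progress, batch_size):
    def _txn(txn):
        last_id = progress.get("last_id", "")
        txn.execute(
            "SELECT id FROM some_table WHERE id > ? ORDER BY id LIMIT ?",
            (last_id, batch_size),
        )
        rows = [row[0] for row in txn]
        if rows:
            self.db_pool.updates._background_update_progress_txn(
                txn, "example_update", {"last_id": rows[-1]}
            )
        return len(rows)

    processed = await self.db_pool.runInteraction("example_update", _txn)
    if not processed:
        await self.db_pool.updates._end_background_update("example_update")
    return processed
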
async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int
@@ -965,7 +969,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
updates = []
for room_id, event_json in txn:
- event_dict = json.loads(event_json)
+ event_dict = db_to_json(event_json)
room_version_id = event_dict.get("content", {}).get(
"room_version", RoomVersions.V1.identifier
)
@@ -983,7 +987,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
# mainly for paranoia as much badness would happen if we don't
# insert the row and then try and get the room version for the
# room.
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
@@ -992,19 +996,19 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
)
new_last_room_id = room_id
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
)
return False
- end = await self.db.runInteraction(
+ end = await self.db_pool.runInteraction(
"_background_add_rooms_room_version_column",
_background_add_rooms_room_version_column_txn,
)
if end:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.ADD_ROOMS_ROOM_VERSION_COLUMN
)
@@ -1038,12 +1042,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return [row[0] for row in txn]
- rooms = await self.db.runInteraction(
+ rooms = await self.db_pool.runInteraction(
"get_tombstoned_directory_rooms", _get_rooms
)
if not rooms:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
)
return 0
@@ -1052,7 +1056,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Removing tombstoned room %s from the directory", room_id)
await self.set_room_is_public(room_id, False)
- await self.db.updates._background_update_progress(
+ await self.db_pool.updates._background_update_progress(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
)
@@ -1068,7 +1072,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs)
self.config = hs.config
@@ -1079,7 +1083,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Called when we join a room over federation, and overwrites any room version
currently in the table.
"""
- await self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="upsert_room_on_join",
table="rooms",
keyvalues={"room_id": room_id},
@@ -1090,8 +1094,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False,
)
- @defer.inlineCallbacks
- def store_room(
+ async def store_room(
self,
room_id: str,
room_creator_user_id: str,
@@ -1112,7 +1115,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
try:
def store_room_txn(txn, next_id):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"rooms",
{
@@ -1123,7 +1126,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
if is_public:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -1133,8 +1136,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
- yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
+ with await self._public_room_id_gen.get_next() as next_id:
+ await self.db_pool.runInteraction(
+ "store_room_txn", store_room_txn, next_id
+ )
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@@ -1144,7 +1149,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
When we receive an invite over federation, store the version of the room if we
don't already know the room version.
"""
- await self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="maybe_store_room_on_invite",
table="rooms",
keyvalues={"room_id": room_id},
@@ -1159,17 +1164,16 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False,
)
- @defer.inlineCallbacks
- def set_room_is_public(self, room_id, is_public):
+ async def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"is_public": is_public},
)
- entries = self.db.simple_select_list_txn(
+ entries = self.db_pool.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -1187,7 +1191,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -1199,14 +1203,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
- yield self.db.runInteraction(
+ with await self._public_room_id_gen.get_next() as next_id:
+ await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
- @defer.inlineCallbacks
- def set_room_is_public_appservice(
+ async def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public
):
"""Edit the appservice/network specific public room list.
@@ -1227,7 +1230,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="appservice_room_list",
values={
@@ -1240,7 +1243,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# We've already inserted, nothing to do.
return
else:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="appservice_room_list",
keyvalues={
@@ -1250,7 +1253,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- entries = self.db.simple_select_list_txn(
+ entries = self.db_pool.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -1268,7 +1271,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -1280,16 +1283,16 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with self._public_room_id_gen.get_next() as next_id:
- yield self.db.runInteraction(
+ with await self._public_room_id_gen.get_next() as next_id:
+ await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
next_id,
)
self.hs.get_notifier().on_new_replication_data()
- def get_room_count(self):
- """Retrieve a list of all rooms
+ async def get_room_count(self) -> int:
+ """Retrieve the total number of rooms.
"""
def f(txn):
@@ -1298,13 +1301,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.db.runInteraction("get_rooms", f)
+ return await self.db_pool.runInteraction("get_rooms", f)
- def add_event_report(
- self, room_id, event_id, user_id, reason, content, received_ts
- ):
+ async def add_event_report(
+ self,
+ room_id: str,
+ event_id: str,
+ user_id: str,
+ reason: str,
+ content: JsonDict,
+ received_ts: int,
+ ) -> None:
next_id = self._event_reports_id_gen.get_next()
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
table="event_reports",
values={
"id": next_id,
@@ -1313,7 +1322,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"event_id": event_id,
"user_id": user_id,
"reason": reason,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
desc="add_event_report",
)
@@ -1321,52 +1330,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- @defer.inlineCallbacks
- def block_room(self, room_id, user_id):
+ async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked. Can be called multiple times.
Args:
- room_id (str): Room to block
- user_id (str): Who blocked it
-
- Returns:
- Deferred
+ room_id: Room to block
+ user_id: Who blocked it
"""
- yield self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="blocked_rooms",
keyvalues={"room_id": room_id},
values={},
insertion_values={"user_id": user_id},
desc="block_room",
)
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked,
(room_id,),
)
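
Since block_room now both upserts the row and invalidates the is_room_blocked cache, a caller can rely on a subsequent cached read seeing the new state; a small hedged sketch (hypothetical helper):

async def block_and_verify(store, room_id: str, admin_user_id: str) -> bool:
    # Illustrative: after block_room the is_room_blocked cache has been
    # invalidated (locally and via the replication stream), so this read is fresh.
    await store.block_room(room_id, admin_user_id)
    return bool(await store.is_room_blocked(room_id))
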
- @defer.inlineCallbacks
- def get_rooms_for_retention_period_in_range(
- self, min_ms, max_ms, include_null=False
- ):
+ async def get_rooms_for_retention_period_in_range(
+ self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+ ) -> Dict[str, dict]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
Args:
- min_ms (int|None): Duration in milliseconds that define the lower limit of
+            min_ms: Duration in milliseconds that defines the lower limit of
the range to handle (exclusive). If None, doesn't set a lower limit.
- max_ms (int|None): Duration in milliseconds that define the upper limit of
+            max_ms: Duration in milliseconds that defines the upper limit of
the range to handle (inclusive). If None, doesn't set an upper limit.
- include_null (bool): Whether to include rooms which retention policy is NULL
+            include_null: Whether to include rooms whose retention policy is NULL
in the returned set.
Returns:
- dict[str, dict]: The rooms within this range, along with their retention
- policy. The key is "room_id", and maps to a dict describing the retention
- policy associated with this room ID. The keys for this nested dict are
- "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
"""
def get_rooms_for_retention_period_in_range_txn(txn):
@@ -1396,7 +1400,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql, args)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
rooms_dict = {}
for row in rows:
@@ -1412,7 +1416,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
@@ -1425,9 +1429,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict
- rooms = yield self.db.runInteraction(
+ rooms = await self.db_pool.runInteraction(
"get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn,
)
- defer.returnValue(rooms)
+ return rooms
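
A hedged usage sketch of the return shape documented above (the caller and thresholds are illustrative):

async def rooms_in_retention_band(store, shortest_ms, longest_ms):
    # Illustrative caller; include_null=True also returns rooms without an
    # explicit retention policy so server defaults can be applied to them.
    rooms = await store.get_rooms_for_retention_period_in_range(
        shortest_ms, longest_ms, include_null=True
    )
    # Each value is {"min_lifetime": int|None, "max_lifetime": int|None}.
    return [(room_id, policy["max_lifetime"]) for room_id, policy in rooms.items()]
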
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 137ebac833..91a8b43da3 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,24 +15,21 @@
# limitations under the License.
import logging
-from typing import Iterable, List, Set
-
-from six import iteritems, itervalues
-
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import (
LoggingTransaction,
SQLBaseStore,
+ db_to_json,
make_in_list_sql_clause,
)
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
GetRoomsForUserWithStreamOrdering,
@@ -43,9 +40,12 @@ from synapse.storage.roommember import (
from synapse.types import Collection, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.state import _StateCacheEntry
+
logger = logging.getLogger(__name__)
@@ -54,7 +54,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
# Is the current_state_events.membership up to date? Or is the
@@ -90,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
- @defer.inlineCallbacks
- def _count_known_servers(self):
+ async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -119,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.db.runInteraction("get_known_servers", _transact)
+ count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -131,7 +130,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership column is up to date
"""
- pending_update = self.db.simple_select_one_txn(
+ pending_update = self.db_pool.simple_select_one_txn(
txn,
table="background_updates",
keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@@ -147,18 +146,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
15.0,
run_as_background_process,
"_check_safe_current_state_events_membership_updated",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_check_safe_current_state_events_membership_updated",
self._check_safe_current_state_events_membership_updated_txn,
)
@cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id):
- return self.db.runInteraction(
+ async def get_users_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
- def get_users_in_room_txn(self, txn, room_id):
+ def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@@ -181,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
- def get_room_summary(self, room_id):
+ async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
- room_id (str): The room ID to query
+ room_id: The room ID to query
Returns:
- Deferred[dict[str, MemberSummary]:
- dict of membership states, pointing to a MemberSummary named tuple.
+ dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
@@ -262,80 +260,63 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
- return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
-
- def _get_user_counts_in_room_txn(self, txn, room_id):
- """
- Get the user count in a room by membership.
-
- Args:
- room_id (str)
- membership (Membership)
-
- Returns:
- Deferred[int]
- """
- sql = """
- SELECT m.membership, count(*) FROM room_memberships as m
- INNER JOIN current_state_events as c USING(event_id)
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
-
- txn.execute(sql, (room_id,))
- return {row[0]: row[1] for row in txn}
+ return await self.db_pool.runInteraction(
+ "get_room_summary", _get_room_summary_txn
+ )
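
A hedged sketch of consuming the summary when building lazy-loaded member counts; it assumes MemberSummary exposes a `count` field, which is defined outside this patch:

from synapse.api.constants import Membership

async def room_member_counts(store, room_id: str) -> dict:
    # Illustrative consumer of get_room_summary.
    summary = await store.get_room_summary(room_id)
    joined = summary.get(Membership.JOIN)
    invited = summary.get(Membership.INVITE)
    return {
        "joined": joined.count if joined else 0,
        "invited": invited.count if invited else 0,
    }
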
@cached()
- def get_invited_rooms_for_local_user(self, user_id):
- """ Get all the rooms the *local* user is invited to
+ async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
+ """Get all the rooms the *local* user is invited to.
Args:
- user_id (str): The user ID.
+ user_id: The user ID.
+
Returns:
- A deferred list of RoomsForUser.
+ A list of RoomsForUser.
"""
- return self.get_rooms_for_local_user_where_membership_is(
+ return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
- @defer.inlineCallbacks
- def get_invite_for_local_user_in_room(self, user_id, room_id):
- """Gets the invite for the given *local* user and room
+ async def get_invite_for_local_user_in_room(
+ self, user_id: str, room_id: str
+ ) -> Optional[RoomsForUser]:
+ """Gets the invite for the given *local* user and room.
Args:
- user_id (str)
- room_id (str)
+            user_id: The user ID to find the invite for.
+            room_id: The room the user was invited to.
Returns:
- Deferred: Resolves to either a RoomsForUser or None if no invite was
- found.
+ Either a RoomsForUser or None if no invite was found.
"""
- invites = yield self.get_invited_rooms_for_local_user(user_id)
+ invites = await self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
- @defer.inlineCallbacks
- def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
- """ Get all the rooms for this *local* user where the membership for this user
+ async def get_rooms_for_local_user_where_membership_is(
+ self, user_id: str, membership_list: Collection[str]
+ ) -> List[RoomsForUser]:
+ """Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
Args:
- user_id (str): The user ID.
- membership_list (list): A list of synapse.api.constants.Membership
- values which the user must be in.
+ user_id: The user ID.
+ membership_list: A list of synapse.api.constants.Membership
+ values which the user must be in.
Returns:
- Deferred[list[RoomsForUser]]
+            A list of RoomsForUser for the rooms where the user's membership matches
+            one of the given types.
"""
if not membership_list:
- return defer.succeed(None)
+ return []
- rooms = yield self.db.runInteraction(
+ rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
@@ -343,12 +324,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
# Now we filter out forgotten rooms
- forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+ forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
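
A short hedged example of calling this helper, e.g. to list the rooms a local user would see in their room list (the caller is illustrative):

from synapse.api.constants import Membership

async def visible_rooms_for_user(store, user_id: str):
    # Illustrative caller: joined and invited rooms for a local user, with
    # forgotten rooms already filtered out by the store method.
    rooms = await store.get_rooms_for_local_user_where_membership_is(
        user_id, [Membership.JOIN, Membership.INVITE]
    )
    return [room.room_id for room in rooms]
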
def _get_rooms_for_local_user_where_membership_is_txn(
- self, txn, user_id, membership_list
- ):
+ self, txn, user_id: str, membership_list: List[str]
+ ) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
@@ -372,32 +353,35 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+ results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
return results
@cached(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id):
+ async def get_rooms_for_user_with_stream_ordering(
+ self, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
Args:
- user_id (str)
+ user_id
Returns:
- Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
- the rooms the user is in currently, along with the stream ordering
- of the most recent join for that user and room.
+ Returns the rooms the user is in currently, along with the stream
+ ordering of the most recent join for that user and room.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
- def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+ def _get_rooms_for_user_with_stream_ordering_txn(
+ self, txn, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@@ -424,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
- return results
+ return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@@ -456,42 +438,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0] for row in txn}
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_server_still_shares_room_with",
_get_users_server_still_shares_room_with_txn,
)
- @defer.inlineCallbacks
- def get_rooms_for_user(self, user_id, on_invalidate=None):
+ async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
"""
- rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ rooms = await self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
)
return frozenset(r.room_id for r in rooms)
- @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
- def get_users_who_share_room_with_user(self, user_id, cache_context):
+ @cached(max_entries=500000, cache_context=True, iterable=True)
+ async def get_users_who_share_room_with_user(
+ self, user_id: str, cache_context: _CacheContext
+ ) -> Set[str]:
"""Returns the set of users who share a room with `user_id`
"""
- room_ids = yield self.get_rooms_for_user(
+ room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = yield self.get_users_in_room(
+ user_ids = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
user_who_share_room.update(user_ids)
return user_who_share_room
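
The cache_context=True pattern above is what ties this cache entry to the nested cached lookups it depends on. A minimal hedged sketch of the same shape (the method itself is illustrative and not part of this patch; cached, _CacheContext and get_domain_from_id are the names imported at the top of this file):

@cached(cache_context=True)
async def get_user_domains_in_room(
    self, room_id: str, cache_context: _CacheContext
) -> Set[str]:
    # Illustrative method: it depends on get_users_in_room, so it passes
    # on_invalidate=cache_context.invalidate; when that cache is invalidated
    # for the room, this entry is dropped as well.
    user_ids = await self.get_users_in_room(
        room_id, on_invalidate=cache_context.invalidate
    )
    return {get_domain_from_id(user_id) for user_id in user_ids}
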
- @defer.inlineCallbacks
- def get_joined_users_from_context(self, event, context):
+ async def get_joined_users_from_context(
+ self, event: EventBase, context: EventContext
+ ):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -500,14 +484,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids()
- result = yield self._get_joined_users_from_context(
+ current_state_ids = await context.get_current_state_ids()
+ return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
- return result
- @defer.inlineCallbacks
- def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -517,16 +499,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_users_from_state"):
- return (
- yield self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
- )
+ return await self._get_joined_users_from_context(
+ room_id, state_group, state_entry.state, context=state_entry
)
- @cachedInlineCallbacks(
- num_args=2, cache_context=True, iterable=True, max_entries=100000
- )
- def _get_joined_users_from_context(
+ @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+ async def _get_joined_users_from_context(
self,
room_id,
state_group,
@@ -538,13 +516,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
users_in_room = {}
member_event_ids = [
e_id
- for key, e_id in iteritems(current_state_ids)
+ for key, e_id in current_state_ids.items()
if key[0] == EventTypes.Member
]
@@ -561,7 +538,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res)
member_event_ids = [
e_id
- for key, e_id in iteritems(context.delta_ids)
+ for key, e_id in context.delta_ids.items()
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
@@ -591,7 +568,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
- event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+ event_to_memberships = await self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
users_in_room.update((row for row in event_to_memberships.values() if row))
@@ -611,23 +588,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
- list_name="event_ids",
- inlineCallbacks=True,
+ cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
- def _get_joined_profiles_from_event_ids(self, event_ids):
+ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
- event_ids (Iterable[str]): The member event IDs to lookup
+ event_ids: The member event IDs to lookup
Returns:
- Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -647,8 +622,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for row in rows
}
- @cachedInlineCallbacks(max_entries=10000)
- def is_host_joined(self, room_id, host):
+ @cached(max_entries=10000)
+ async def is_host_joined(self, room_id: str, host: str) -> bool:
if "%" in host or "_" in host:
raise Exception("Invalid host name")
@@ -667,47 +642,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
- rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
-
- if not rows:
- return False
-
- user_id = rows[0][0]
- if get_domain_from_id(user_id) != host:
- # This can only happen if the host name has something funky in it
- raise Exception("Invalid host name")
-
- return True
-
- @cachedInlineCallbacks()
- def was_host_joined(self, room_id, host):
- """Check whether the server is or ever was in the room.
-
- Args:
- room_id (str)
- host (str)
-
- Returns:
- Deferred: Resolves to True if the host is/was in the room, otherwise
- False.
- """
- if "%" in host or "_" in host:
- raise Exception("Invalid host name")
-
- sql = """
- SELECT user_id FROM room_memberships
- WHERE room_id = ?
- AND user_id LIKE ?
- AND membership = 'join'
- LIMIT 1
- """
-
- # We do need to be careful to ensure that host doesn't have any wild cards
- # in it, but we checked above for known ones and we'll check below that
- # the returned user actually has the correct domain.
- like_clause = "%:" + host
-
- rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
+ rows = await self.db_pool.execute(
+ "is_host_joined", None, sql, room_id, like_clause
+ )
if not rows:
return False
@@ -719,8 +656,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
- @defer.inlineCallbacks
- def get_joined_hosts(self, room_id, state_entry):
+ async def get_joined_hosts(self, room_id: str, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -730,32 +666,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_hosts"):
- return (
- yield self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry
- )
+ return await self._get_joined_hosts(
+ room_id, state_group, state_entry.state, state_entry=state_entry
)
- @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- # @defer.inlineCallbacks
- def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+ @cached(num_args=2, max_entries=10000, iterable=True)
+ async def _get_joined_hosts(
+ self, room_id, state_group, current_state_ids, state_entry
+ ):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- cache = yield self._get_joined_hosts_cache(room_id)
- joined_hosts = yield cache.get_destinations(state_entry)
-
- return joined_hosts
+ cache = await self._get_joined_hosts_cache(room_id)
+ return await cache.get_destinations(state_entry)
@cached(max_entries=10000)
- def _get_joined_hosts_cache(self, room_id):
+ def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache(self, room_id)
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
+ @cached(num_args=2)
+ async def did_forget(self, user_id: str, room_id: str) -> bool:
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
@@ -777,18 +709,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall()
return rows[0][0]
- count = yield self.db.runInteraction("did_forget_membership", f)
+ count = await self.db_pool.runInteraction("did_forget_membership", f)
return count == 0
@cached()
- def get_forgotten_rooms_for_user(self, user_id):
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten.
Args:
- user_id (str)
+ user_id: The user ID to query the rooms of.
Returns:
- Deferred[set[str]]
+ The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
@@ -814,22 +746,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
- @defer.inlineCallbacks
- def get_rooms_user_has_been_in(self, user_id):
+ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
Args:
- user_id (str)
+ user_id: The user ID to get the rooms of.
Returns:
- Deferred[set[str]]: Set of room IDs.
+ Set of room IDs.
"""
- room_ids = yield self.db.simple_select_onecol(
+ room_ids = await self.db_pool.simple_select_onecol(
table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id",
@@ -838,13 +769,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
- def get_membership_from_event_ids(
+ async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
- return self.db.simple_select_many_batch(
+ return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -880,23 +811,23 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return bool(txn.fetchone())
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"is_local_host_in_room_ignoring_users",
_is_local_host_in_room_ignoring_users_txn,
)
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
self._background_current_state_membership,
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"room_membership_forgotten_idx",
index_name="room_memberships_user_room_forgotten",
table="room_memberships",
@@ -904,8 +835,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
- @defer.inlineCallbacks
- def _background_add_membership_profile(self, progress, batch_size):
+ async def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start
)
@@ -929,7 +859,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return 0
@@ -940,7 +870,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
event_id = row["event_id"]
room_id = row["room_id"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
@@ -964,25 +894,24 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive": min_stream_id,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
)
return len(rows)
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
)
if not result:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
_MEMBERSHIP_PROFILE_UPDATE_NAME
)
return result
- @defer.inlineCallbacks
- def _background_current_state_membership(self, progress, batch_size):
+ async def _background_current_state_membership(self, progress, batch_size):
"""Update the new membership column on current_state_events.
This works by iterating over all rooms in alphebetical order.
@@ -1016,7 +945,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
last_processed_room = next_room
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn,
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
{"last_processed_room": last_processed_room},
@@ -1028,14 +957,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
# string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "")
- row_count, finished = yield self.db.runInteraction(
+ row_count, finished = await self.db_pool.runInteraction(
"_background_current_state_membership_update",
_background_current_state_membership_txn,
last_processed_room,
)
if finished:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
)
@@ -1043,10 +972,10 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def forget(self, user_id, room_id):
+ async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -1067,10 +996,10 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
- return self.db.runInteraction("forget_membership", f)
+ await self.db_pool.runInteraction("forget_membership", f)
-class _JoinedHostsCache(object):
+class _JoinedHostsCache:
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.
"""
@@ -1087,21 +1016,23 @@ class _JoinedHostsCache(object):
self._len = 0
- @defer.inlineCallbacks
- def get_destinations(self, state_entry):
+ async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
"""Get set of destinations for a state entry
Args:
- state_entry(synapse.state._StateCacheEntry)
+ state_entry
+
+ Returns:
+ The destinations as a set.
"""
if state_entry.state_group == self.state_group:
return frozenset(self.hosts_to_joined_users)
- with (yield self.linearizer.queue(())):
+ with (await self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
+ for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
@@ -1109,7 +1040,7 @@ class _JoinedHostsCache(object):
user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set())
- event = yield self.store.get_event(event_id)
+ event = await self.store.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
@@ -1118,7 +1049,7 @@ class _JoinedHostsCache(object):
if not known_joins:
self.hosts_to_joined_users.pop(host, None)
else:
- joined_users = yield self.store.get_joined_users_from_state(
+ joined_users = await self.store.get_joined_users_from_state(
self.room_id, state_entry
)
@@ -1131,7 +1062,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group
else:
self.state_group = object()
- self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
+ self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
return frozenset(self.hosts_to_joined_users)
def __len__(self):
diff --git a/synapse/storage/data_stores/main/schema/delta/12/v12.sql b/synapse/storage/databases/main/schema/delta/12/v12.sql
index 5964c5aaac..5964c5aaac 100644
--- a/synapse/storage/data_stores/main/schema/delta/12/v12.sql
+++ b/synapse/storage/databases/main/schema/delta/12/v12.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/13/v13.sql b/synapse/storage/databases/main/schema/delta/13/v13.sql
index f8649e5d99..f8649e5d99 100644
--- a/synapse/storage/data_stores/main/schema/delta/13/v13.sql
+++ b/synapse/storage/databases/main/schema/delta/13/v13.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/14/v14.sql b/synapse/storage/databases/main/schema/delta/14/v14.sql
index a831920da6..a831920da6 100644
--- a/synapse/storage/data_stores/main/schema/delta/14/v14.sql
+++ b/synapse/storage/databases/main/schema/delta/14/v14.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql
index e4f5e76aec..e4f5e76aec 100644
--- a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
+++ b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql
index 6b8d0f1ca7..6b8d0f1ca7 100644
--- a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
+++ b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/15/v15.sql b/synapse/storage/databases/main/schema/delta/15/v15.sql
index 9523d2bcc3..9523d2bcc3 100644
--- a/synapse/storage/data_stores/main/schema/delta/15/v15.sql
+++ b/synapse/storage/databases/main/schema/delta/15/v15.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql
index a48f215170..a48f215170 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
+++ b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql
index 7a15265cb1..7a15265cb1 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
+++ b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql
index 65c97b5e2f..65c97b5e2f 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
+++ b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql
index f82486132b..f82486132b 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
+++ b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql
index 5b8de52c33..5b8de52c33 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
+++ b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/users.sql b/synapse/storage/databases/main/schema/delta/16/users.sql
index cd0709250d..cd0709250d 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/users.sql
+++ b/synapse/storage/databases/main/schema/delta/16/users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql
index 7c9a90e27f..7c9a90e27f 100644
--- a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
+++ b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql b/synapse/storage/databases/main/schema/delta/17/server_keys.sql
index 70b247a06b..70b247a06b 100644
--- a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/17/server_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql
index c17715ac80..c17715ac80 100644
--- a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
+++ b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql
index 6e0871c92b..6e0871c92b 100644
--- a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
+++ b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql b/synapse/storage/databases/main/schema/delta/19/event_index.sql
index 18b97b4332..18b97b4332 100644
--- a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
+++ b/synapse/storage/databases/main/schema/delta/19/event_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql b/synapse/storage/databases/main/schema/delta/20/dummy.sql
index e0ac49d1ec..e0ac49d1ec 100644
--- a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
+++ b/synapse/storage/databases/main/schema/delta/20/dummy.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..3edfcfd783 100644
--- a/synapse/storage/data_stores/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
diff --git a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql
index 4c2fb20b77..4c2fb20b77 100644
--- a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql b/synapse/storage/databases/main/schema/delta/21/receipts.sql
index d070845477..d070845477 100644
--- a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
+++ b/synapse/storage/databases/main/schema/delta/21/receipts.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql
index bfc0b3bcaa..bfc0b3bcaa 100644
--- a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
+++ b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql
index 87edfa454c..87edfa454c 100644
--- a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
+++ b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql
index acea7483bd..acea7483bd 100644
--- a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
+++ b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index 4b2ffd35fd..ee675e71ff 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.prepare_database import get_statements
@@ -66,7 +64,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql b/synapse/storage/databases/main/schema/delta/25/guest_access.sql
index 1ea389b471..1ea389b471 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
+++ b/synapse/storage/databases/main/schema/delta/25/guest_access.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql
index f468fc1897..f468fc1897 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
+++ b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/25/tags.sql b/synapse/storage/databases/main/schema/delta/25/tags.sql
index 7a32ce68e4..7a32ce68e4 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/tags.sql
+++ b/synapse/storage/databases/main/schema/delta/25/tags.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql b/synapse/storage/databases/main/schema/delta/26/account_data.sql
index e395de2b5e..e395de2b5e 100644
--- a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
+++ b/synapse/storage/databases/main/schema/delta/26/account_data.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql b/synapse/storage/databases/main/schema/delta/27/account_data.sql
index bf0558b5b3..bf0558b5b3 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
+++ b/synapse/storage/databases/main/schema/delta/27/account_data.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql
index e2094f37fe..e2094f37fe 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
+++ b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index 414f9f5aa0..b7972cfa8e 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql
index 4d519849df..4d519849df 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql
index 36609475f1..36609475f1 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql
index 6c1fd68c5b..6c1fd68c5b 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
+++ b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql
index cb84c69baa..cb84c69baa 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
+++ b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql
index 3e4a9ab455..3e4a9ab455 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
+++ b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql
index 21d2b420bf..21d2b420bf 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
+++ b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql b/synapse/storage/databases/main/schema/delta/29/push_actions.sql
index 84b21cf813..84b21cf813 100644
--- a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
+++ b/synapse/storage/databases/main/schema/delta/29/push_actions.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql
index c9d0dde638..c9d0dde638 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
+++ b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index 9b95411fb6..b42c02710a 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -13,8 +13,6 @@
# limitations under the License.
import logging
-from six.moves import range
-
from synapse.config.appservice import load_appservices
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql
index 712c454aa1..712c454aa1 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql
index 606bbb037d..606bbb037d 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql
index f09db4faa6..f09db4faa6 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql
index 735aa8d5f6..735aa8d5f6 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql
index 0dd2f1360c..0dd2f1360c 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
+++ b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/invites.sql b/synapse/storage/databases/main/schema/delta/31/invites.sql
index 2c57846d5a..2c57846d5a 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/invites.sql
+++ b/synapse/storage/databases/main/schema/delta/31/invites.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql
index 9efb4280eb..9efb4280eb 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..9bb504aad5 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql
index a82add88fd..a82add88fd 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
+++ b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 7d8ca5f93f..63b757ade6 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
@@ -50,7 +48,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"rows_inserted": 0,
"have_added_indexes": False,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/32/events.sql b/synapse/storage/databases/main/schema/delta/32/events.sql
index 1dd0f9e170..1dd0f9e170 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/events.sql
+++ b/synapse/storage/databases/main/schema/delta/32/events.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/openid.sql b/synapse/storage/databases/main/schema/delta/32/openid.sql
index 36f37b11c8..36f37b11c8 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/openid.sql
+++ b/synapse/storage/databases/main/schema/delta/32/openid.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql
index d86d30c13c..d86d30c13c 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
+++ b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql
index 2de50d408c..2de50d408c 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/reports.sql b/synapse/storage/databases/main/schema/delta/32/reports.sql
index d13609776f..d13609776f 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/reports.sql
+++ b/synapse/storage/databases/main/schema/delta/32/reports.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql
index 61ad3fe3e8..61ad3fe3e8 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
+++ b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices.sql b/synapse/storage/databases/main/schema/delta/33/devices.sql
index eca7268d82..eca7268d82 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/devices.sql
+++ b/synapse/storage/databases/main/schema/delta/33/devices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql
index aa4a3b9f2f..aa4a3b9f2f 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
index 6671573398..6671573398 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
+++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index bff1256a7b..a3e81eeac7 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import json
import logging
-import simplejson
-
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = simplejson.dumps(progress)
+ progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..a26057dfb6 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
diff --git a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql
index 473f75a78e..473f75a78e 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
+++ b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql
index 69e16eda0f..69e16eda0f 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py b/synapse/storage/databases/main/schema/delta/34/cache_stream.py
index cf09e43e2b..cf09e43e2b 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
+++ b/synapse/storage/databases/main/schema/delta/34/cache_stream.py
diff --git a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql
index e68844c74a..e68844c74a 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
+++ b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql
index 0d9fe1a99a..0d9fe1a99a 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
+++ b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py
index 67d505e68b..67d505e68b 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py
diff --git a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql b/synapse/storage/databases/main/schema/delta/35/contains_url.sql
index 6cd123027b..6cd123027b 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
+++ b/synapse/storage/databases/main/schema/delta/35/contains_url.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql
index 17e6c43105..17e6c43105 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
+++ b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql
index 7ab7d942e2..7ab7d942e2 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
+++ b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql
index 2e836d8e9c..2e836d8e9c 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
+++ b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql
index dd2bf2e28a..dd2bf2e28a 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql
index 2b945d8a57..2b945d8a57 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
+++ b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql
index 90d8fd18f9..90d8fd18f9 100644
--- a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py
index a377884169..a377884169 100644
--- a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py
diff --git a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql
index cf7a90dd10..cf7a90dd10 100644
--- a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
+++ b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql
index 515e6b8e84..515e6b8e84 100644
--- a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql
index 74bdc49073..74bdc49073 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
+++ b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql
index 00be801e90..00be801e90 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql
index de2ad93e5c..de2ad93e5c 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
+++ b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql
index 5af814290b..5af814290b 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
+++ b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql
index 1bf911c8ab..1bf911c8ab 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
+++ b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql
index 7ffa189f39..7ffa189f39 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql
index b9fe1f0480..b9fe1f0480 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
+++ b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql
index dd6dcb65f1..dd6dcb65f1 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
+++ b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql
index 3918f0b794..3918f0b794 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
+++ b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql b/synapse/storage/databases/main/schema/delta/40/pushers.sql
index 054a223f14..054a223f14 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/40/pushers.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql
index b7bee8b692..b7bee8b692 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql
index 62f0b9892b..62f0b9892b 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
+++ b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql
index 5d9cfecf36..5d9cfecf36 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql
index a194bf0238..a194bf0238 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
+++ b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql
index d28851aff8..d28851aff8 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
+++ b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql
index 9ab8c14fa3..9ab8c14fa3 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
+++ b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql
index b8821ac759..b8821ac759 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
+++ b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py b/synapse/storage/databases/main/schema/delta/42/user_dir.py
index 506f326f4d..506f326f4d 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
+++ b/synapse/storage/databases/main/schema/delta/42/user_dir.py
diff --git a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql
index 0e3cd143ff..0e3cd143ff 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql
index 630907ec4f..630907ec4f 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
+++ b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql b/synapse/storage/databases/main/schema/delta/43/url_cache.sql
index 45ebe020da..45ebe020da 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/43/url_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql b/synapse/storage/databases/main/schema/delta/43/user_share.sql
index ee7062abe4..ee7062abe4 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
+++ b/synapse/storage/databases/main/schema/delta/43/user_share.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql
index b12f9b2ebf..b12f9b2ebf 100644
--- a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql b/synapse/storage/databases/main/schema/delta/45/group_server.sql
index b2333848a0..b2333848a0 100644
--- a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
+++ b/synapse/storage/databases/main/schema/delta/45/group_server.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql
index e5ddc84df0..e5ddc84df0 100644
--- a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql
index 68c48a89a9..68c48a89a9 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
+++ b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql
index bb307889c1..bb307889c1 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql b/synapse/storage/databases/main/schema/delta/46/group_server.sql
index 097679bc9a..097679bc9a 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
+++ b/synapse/storage/databases/main/schema/delta/46/group_server.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql
index bbfc7f5d1a..bbfc7f5d1a 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql
index cb0d5a2576..cb0d5a2576 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
+++ b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql
index d9505f8da1..d9505f8da1 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
+++ b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql
index f505fb22b5..f505fb22b5 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
+++ b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql
index 31d7a817eb..31d7a817eb 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
+++ b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql
index edccf4a96f..edccf4a96f 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
+++ b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql
index 5237491506..5237491506 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
+++ b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql
index 9248b0b24a..9248b0b24a 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
+++ b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql
index e9013a6969..e9013a6969 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
+++ b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py
index 49f5f2c003..49f5f2c003 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
+++ b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py
diff --git a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql
index ce26eaf0c9..ce26eaf0c9 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
+++ b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql
index 14dcf18d73..14dcf18d73 100644
--- a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
+++ b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql
index 3dd478196f..3dd478196f 100644
--- a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
+++ b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
index 3a4ed59b5b..3a4ed59b5b 100644
--- a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
+++ b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql
index c93ae47532..c93ae47532 100644
--- a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
+++ b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql
index 5d8641a9ab..5d8641a9ab 100644
--- a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
+++ b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py
index b1684a8441..b1684a8441 100644
--- a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
+++ b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py
diff --git a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql
index c0e66a697d..c0e66a697d 100644
--- a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql
index c9d537d5a3..c9d537d5a3 100644
--- a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
+++ b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql
index 91e03d13e1..91e03d13e1 100644
--- a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
+++ b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql
index bfa49e6f92..bfa49e6f92 100644
--- a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql
index db687cccae..db687cccae 100644
--- a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql
index 88ec2f83e5..88ec2f83e5 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
+++ b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql
index e372f5a44a..e372f5a44a 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
+++ b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql
index 1d977c2834..1d977c2834 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
+++ b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql
index ffcc896b58..ffcc896b58 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql
index b812c5794f..b812c5794f 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql b/synapse/storage/databases/main/schema/delta/53/user_share.sql
index 5831b1a6f8..5831b1a6f8 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_share.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql
index 80c2c573b6..80c2c573b6 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql
index f7827ca6d2..f7827ca6d2 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql
index 0adb2ad55e..0adb2ad55e 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
+++ b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql
index c01aa9d2d9..c01aa9d2d9 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql
index b062ec840c..b062ec840c 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
+++ b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql
index dbbe682697..dbbe682697 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
+++ b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql
index e6ee70c623..e6ee70c623 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
+++ b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/relations.sql b/synapse/storage/databases/main/schema/delta/54/relations.sql
index 134862b870..134862b870 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/relations.sql
+++ b/synapse/storage/databases/main/schema/delta/54/relations.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats.sql b/synapse/storage/databases/main/schema/delta/54/stats.sql
index 652e58308e..652e58308e 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/stats.sql
+++ b/synapse/storage/databases/main/schema/delta/54/stats.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql b/synapse/storage/databases/main/schema/delta/54/stats2.sql
index 3b2d48447f..3b2d48447f 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
+++ b/synapse/storage/databases/main/schema/delta/54/stats2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql
index 4590604bfd..4590604bfd 100644
--- a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
+++ b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql
index a8eced2e0a..a8eced2e0a 100644
--- a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
+++ b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql
index dabdde489b..dabdde489b 100644
--- a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
+++ b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql
index 41807eb1e7..41807eb1e7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
+++ b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql
index 473018676f..473018676f 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
+++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql
index 3133d42d4a..3133d42d4a 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
+++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql
index 1d2ddb1b1a..1d2ddb1b1a 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
+++ b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql
index f00889290b..f00889290b 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
+++ b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
index b9bbb18a91..b9bbb18a91 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql
index c2f557fde9..c2f557fde9 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
+++ b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql
index dfa902d0ba..dfa902d0ba 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
+++ b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql
index 9f09922c67..9f09922c67 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
+++ b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql
index 81a36a8b1d..81a36a8b1d 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..5e29c1da19 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql
index 5f5e0499ae..5f5e0499ae 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql
index 014cb3b538..014cb3b538 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
+++ b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql
index 67f8b20297..67f8b20297 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
+++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite
index e8b1fd35d8..e8b1fd35d8 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
+++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql
index 4f24c1405d..4f24c1405d 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
+++ b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql
index 7be31ffebb..7be31ffebb 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql
index ea95db0ed7..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql
index 49ce35d794..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
index 67471f3ef5..67471f3ef5 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql
index b7550f6f4e..b7550f6f4e 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
index aeb17813d3..aeb17813d3 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
+++ b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql
index 7d70dd071e..7d70dd071e 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
+++ b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql
index 92ab1f5e65..92ab1f5e65 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/databases/main/schema/delta/56/room_retention.sql
index ee6cdf7a14..ee6cdf7a14 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
+++ b/synapse/storage/databases/main/schema/delta/56/room_retention.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql
index 5c5fffcafb..5c5fffcafb 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql
index 0aa90ebf0c..0aa90ebf0c 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
+++ b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql
index bbdde121e8..bbdde121e8 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
+++ b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..1de8b54961 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql
index 91390c4527..91390c4527 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
+++ b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql
index 149f8be8b6..149f8be8b6 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql
index aec06c8261..aec06c8261 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
+++ b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql
index c3b6de2099..c3b6de2099 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
+++ b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..63b5acdcf7 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql
index 133d80af35..133d80af35 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
+++ b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql
index 352a66f5b0..352a66f5b0 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres
index c601cff6de..c601cff6de 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite
index 335c6f2074..335c6f2074 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres
index 92aaadde0d..92aaadde0d 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite
index e19dab97cb..e19dab97cb 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql
index fdc39e9ba5..fdc39e9ba5 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
+++ b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql
index dcb593fc2d..dcb593fc2d 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
+++ b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres
index aa46eb0e10..aa46eb0e10 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py
index d353f2bcb3..d353f2bcb3 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
+++ b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+ session_id TEXT NOT NULL,
+ ip TEXT NOT NULL,
+ user_agent TEXT NOT NULL,
+ UNIQUE (session_id, ip, user_agent),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
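
A note on the new table: the UNIQUE constraint means that completing the same UI-auth step repeatedly from the same IP and user agent stores only one row. A minimal standalone sketch (plain sqlite3, with a simplified stand-in for the existing ui_auth_sessions table; not Synapse's actual storage code):

import sqlite3

conn = sqlite3.connect(":memory:")
# Simplified stand-in for the existing ui_auth_sessions table.
conn.execute("CREATE TABLE ui_auth_sessions (session_id TEXT PRIMARY KEY)")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
        session_id TEXT NOT NULL,
        ip TEXT NOT NULL,
        user_agent TEXT NOT NULL,
        UNIQUE (session_id, ip, user_agent),
        FOREIGN KEY (session_id) REFERENCES ui_auth_sessions (session_id)
    )
    """
)
conn.execute("INSERT INTO ui_auth_sessions (session_id) VALUES ('abc')")
row = ("abc", "203.0.113.5", "Mozilla/5.0")
# Completing two auth steps from the same IP/user agent only records one row.
conn.execute("INSERT OR IGNORE INTO ui_auth_sessions_ips VALUES (?, ?, ?)", row)
conn.execute("INSERT OR IGNORE INTO ui_auth_sessions_ips VALUES (?, ?, ?)", row)
assert conn.execute("SELECT COUNT(*) FROM ui_auth_sessions_ips").fetchone()[0] == 1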
diff --git a/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
new file mode 100644
index 0000000000..597f2ffd3d
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;
diff --git a/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
new file mode 100644
index 0000000000..69db89ac0e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;
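
The two variants of this delta exist only because the default has to be spelled FALSE on Postgres and 0 on SQLite; the column means the same thing on both. A rough sketch of the intent, using a hypothetical cut-down local_media_repository with an illustrative quarantined_by column (not Synapse's actual quarantine query):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE local_media_repository ("
    " media_id TEXT PRIMARY KEY,"
    " quarantined_by TEXT,"
    " safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0)"
)
conn.execute("INSERT INTO local_media_repository (media_id) VALUES ('ordinary_upload')")
conn.execute(
    "INSERT INTO local_media_repository (media_id, safe_from_quarantine)"
    " VALUES ('sticker_pack_image', 1)"
)
# A quarantine sweep leaves media flagged as safe alone.
conn.execute(
    "UPDATE local_media_repository SET quarantined_by = '@admin:example.com'"
    " WHERE safe_from_quarantine = 0"
)
rows = dict(conn.execute("SELECT media_id, quarantined_by FROM local_media_repository"))
assert rows == {"ordinary_upload": "@admin:example.com", "sticker_pack_image": None}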
diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
new file mode 100644
index 0000000000..260b009b48
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A shadow-banned user may be told that their requests succeeded when they were
+-- actually ignored.
+ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql
new file mode 100644
index 0000000000..eb57203e46
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+The version of Synapse 1.16.0 published on PyPI incorrectly contained a migration
+which added a table called 'local_rejections_stream'. This table is not used, so
+we drop it here for anyone who was affected.
+*/
+
+DROP TABLE IF EXISTS local_rejections_stream;
diff --git a/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql
new file mode 100644
index 0000000000..1cc2633aad
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We need to store the stream positions per instance now that federation sending
+-- can be sharded across multiple worker instances.
+-- We default to 'master' as we want the column to be NOT NULL, and we correctly
+-- reset the instance name to match the config each time we start up.
+ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master';
+
+CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name);
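
A sketch of what the new column and unique index buy us: each worker instance can track its own position for the same stream, keyed on (type, instance_name). Standalone illustration only (needs SQLite >= 3.24 for the upsert syntax; not Synapse's actual code):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE federation_stream_position ("
    " type TEXT NOT NULL,"
    " stream_id INTEGER NOT NULL,"
    " instance_name TEXT NOT NULL DEFAULT 'master')"
)
conn.execute(
    "CREATE UNIQUE INDEX federation_stream_position_instance"
    " ON federation_stream_position(type, instance_name)"
)
# Each federation-sender instance records its own position for the same stream type.
for instance, pos in [("master", 100), ("fed_sender1", 95), ("fed_sender1", 96)]:
    conn.execute(
        "INSERT INTO federation_stream_position (type, stream_id, instance_name)"
        " VALUES ('federation', ?, ?)"
        " ON CONFLICT (type, instance_name) DO UPDATE SET stream_id = excluded.stream_id",
        (pos, instance),
    )
positions = dict(
    conn.execute("SELECT instance_name, stream_id FROM federation_stream_position")
)
assert positions == {"master": 100, "fed_sender1": 96}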
diff --git a/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py
new file mode 100644
index 0000000000..4310ec12ce
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py
@@ -0,0 +1,34 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adds a postgres SEQUENCE for generating guest user IDs.
+"""
+
+from synapse.storage.databases.main.registration import (
+ find_max_generated_user_id_localpart,
+)
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if not isinstance(database_engine, PostgresEngine):
+ return
+
+ next_id = find_max_generated_user_id_localpart(cur) + 1
+ cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,))
+
+
+def run_upgrade(*args, **kwargs):
+ pass
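
Once this migration has run on Postgres, later code can allocate fresh numeric localparts straight from the sequence. A hedged sketch, assuming a psycopg2-style cursor (this is not Synapse's actual ID generator):

def next_generated_localpart(cur) -> str:
    # `cur` is a DB-API cursor on a Postgres database where user_id_seq exists,
    # e.g. localpart = next_generated_localpart(txn_cursor)
    cur.execute("SELECT nextval('user_id_seq')")
    (next_id,) = cur.fetchone()
    return str(next_id)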
diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
new file mode 100644
index 0000000000..cade5dcca8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -0,0 +1,32 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Recalculate the stats for all rooms after fixing a bug where joined_members
+-- was erroneously incremented on per-room profile changes.
+
+-- Note that the populate_stats_process_rooms background update is already set to
+-- run if you're upgrading from Synapse <1.0.0.
+
+-- Additionally, if you upgraded to v1.18.0 (which doesn't include this fix), let
+-- this background job run, and then updated to v1.19.0, you'd end up with only
+-- half of your rooms having their stats recalculated after this fix was in place.
+
+-- So we've switched the old `populate_stats_process_rooms` background job to a
+-- no-op, and kicked off a background job with a new name but the same
+-- functionality as the old one. This effectively restarts the background job
+-- from the beginning, without running it twice in a row, supporting both
+-- upgrade use cases.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('populate_stats_process_rooms_2', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
new file mode 100644
index 0000000000..15421b99ac
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This table is no longer used.
+DROP TABLE IF EXISTS presence_allow_inbound;
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..317fba8a5d
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts
+-- (specified in MSC2654) because doing otherwise would result in either performance
+-- issues or reimplementing a substantial part of the push actions logic.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT;
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql
index 883fcd10b2..883fcd10b2 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql
index 10ce2aa7a0..10ce2aa7a0 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql
index 95826da431..95826da431 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql b/synapse/storage/databases/main/schema/full_schemas/16/im.sql
index a1a2aa8e5b..a1a2aa8e5b 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/im.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql
index 11cdffdbb3..11cdffdbb3 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql
index 8f3759bb2a..8f3759bb2a 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql
index 01d2d8f833..01d2d8f833 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql
index c04f4747d9..c04f4747d9 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql b/synapse/storage/databases/main/schema/full_schemas/16/push.sql
index e44465cf45..e44465cf45 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/push.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql
index 318f0d9aa5..318f0d9aa5 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql
index d47da3b12f..d47da3b12f 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql b/synapse/storage/databases/main/schema/full_schemas/16/state.sql
index 96391a8f0e..96391a8f0e 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/state.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql
index 17e67bedac..17e67bedac 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql b/synapse/storage/databases/main/schema/full_schemas/16/users.sql
index f013aa8b18..f013aa8b18 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/users.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..889a9a0ce4 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..a0411ede7e 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql
index 91d21b2921..91d21b2921 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/databases/main/schema/full_schemas/README.md
index c00f287190..c00f287190 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.md
+++ b/synapse/storage/databases/main/schema/full_schemas/README.md
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/databases/main/search.py
index 13f49d8060..f01cf2fd02 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,17 +16,13 @@
import logging
import re
from collections import namedtuple
-
-from six import string_types
-
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import List, Optional, Set
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.storage.database import Database
+from synapse.events import EventBase
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
logger = logging.getLogger(__name__)
@@ -92,16 +88,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
@@ -110,16 +106,15 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
- self.db.updates.register_noop_background_update(
+ self.db_pool.updates.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
- @defer.inlineCallbacks
- def _background_reindex_search(self, progress, batch_size):
+ async def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -144,7 +139,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return 0
@@ -159,7 +154,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
- event_json = json.loads(row["json"])
+ event_json = db_to_json(row["json"])
content = event_json["content"]
except Exception:
continue
@@ -180,7 +175,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# skip over it.
continue
- if not isinstance(value, string_types):
+ if not isinstance(value, str):
# If the event body, name or topic isn't a string
# then skip over it
continue
@@ -204,23 +199,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
"rows_inserted": rows_inserted + len(event_search_rows),
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
- yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+ await self.db_pool.updates._end_background_update(
+ self.EVENT_SEARCH_UPDATE_NAME
+ )
return result
- @defer.inlineCallbacks
- def _background_reindex_gin_search(self, progress, batch_size):
+ async def _background_reindex_gin_search(self, progress, batch_size):
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
@@ -257,15 +253,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
- yield self.db.runWithConnection(create_index)
+ await self.db_pool.runWithConnection(create_index)
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
)
return 1
- @defer.inlineCallbacks
- def _background_reindex_search_order(self, progress, batch_size):
+ async def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -290,14 +285,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
conn.set_session(autocommit=False)
- yield self.db.runWithConnection(create_index)
+ await self.db_pool.runWithConnection(create_index)
pg = dict(progress)
pg["have_added_indexes"] = True
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
pg,
)
@@ -327,18 +322,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
"have_added_indexes": True,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
)
return len(rows), True
- num_rows, finished = yield self.db.runInteraction(
+ num_rows, finished = await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)
if not finished:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_ORDER_UPDATE_NAME
)
@@ -346,11 +341,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs)
- @defer.inlineCallbacks
- def search_msgs(self, room_ids, search_term, keys):
+ async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
@@ -427,15 +421,15 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self.db.execute(
- "search_msgs", self.db.cursor_to_dict, sql, *args
+ results = await self.db_pool.execute(
+ "search_msgs", self.db_pool.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -444,12 +438,12 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
- count_results = yield self.db.execute(
- "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
+ count_results = await self.db_pool.execute(
+ "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -464,19 +458,25 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- @defer.inlineCallbacks
- def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
+ async def search_rooms(
+ self,
+ room_ids: List[str],
+ search_term: str,
+ keys: List[str],
+ limit,
+ pagination_token: Optional[str] = None,
+ ) -> List[dict]:
"""Performs a full text search over events with given keys.
Args:
- room_id (list): The room_ids to search in
- search_term (str): Search term to search for
- keys (list): List of keys to search in, currently supports
- "content.body", "content.name", "content.topic"
- pagination_token (str): A pagination token previously returned
+ room_ids: The room_ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports "content.body",
+ "content.name", "content.topic"
+ pagination_token: A pagination token previously returned
Returns:
- list of dicts
+ Each match as a dictionary.
"""
clauses = []
@@ -579,15 +579,15 @@ class SearchStore(SearchBackgroundUpdateStore):
args.append(limit)
- results = yield self.db.execute(
- "search_rooms", self.db.cursor_to_dict, sql, *args
+ results = await self.db_pool.execute(
+ "search_rooms", self.db_pool.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -596,12 +596,12 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
- count_results = yield self.db.execute(
- "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
+ count_results = await self.db_pool.execute(
+ "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -621,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- def _find_highlights_in_postgres(self, search_query, events):
+ async def _find_highlights_in_postgres(
+ self, search_query: str, events: List[EventBase]
+ ) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -629,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlight the matching parts.
Args:
- search_query (str)
- events (list): A list of events
+ search_query
+ events: A list of events
Returns:
- deferred : A set of strings.
+ A set of strings.
"""
def f(txn):
@@ -686,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
- return self.db.runInteraction("_find_highlights", f)
+ return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
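
The Python-side changes in this file follow the same mechanical pattern used throughout the PR: drop @defer.inlineCallbacks, make the method async, rename self.db to self.db_pool, and turn each yield into an await. Schematically (illustrative method bodies only, assuming Twisted is installed):

from twisted.internet import defer


class BeforeStore:
    # Old style: Twisted generator-based coroutine on the old attribute name.
    @defer.inlineCallbacks
    def _background_reindex_search(self, progress, batch_size):
        result = yield self.db.runInteraction("reindex_search", lambda txn: 0)
        if not result:
            yield self.db.updates._end_background_update("event_search")
        return result


class AfterStore:
    # New style: native coroutine on the renamed DatabasePool attribute.
    async def _background_reindex_search(self, progress, batch_size):
        result = await self.db_pool.runInteraction("reindex_search", lambda txn: 0)
        if not result:
            await self.db_pool.updates._end_background_update("event_search")
        return result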
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 36244d9f5d..c8c67953e4 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unpaddedbase64 import encode_base64
+from typing import Dict, Iterable, List, Tuple
-from twisted.internet import defer
+from unpaddedbase64 import encode_base64
from synapse.storage._base import SQLBaseStore
+from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@@ -31,18 +32,38 @@ class SignatureWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
- def get_event_reference_hashes(self, event_ids):
+ async def get_event_reference_hashes(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, Dict[str, bytes]]:
+ """Get all hashes for given events.
+
+ Args:
+ event_ids: The event IDs to get hashes for.
+
+ Returns:
+ A mapping of event ID to a mapping of algorithm to hash.
+ """
+
def f(txn):
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
- return self.db.runInteraction("get_event_reference_hashes", f)
+ return await self.db_pool.runInteraction("get_event_reference_hashes", f)
+
+ async def add_event_hashes(
+ self, event_ids: Iterable[str]
+ ) -> List[Tuple[str, Dict[str, str]]]:
+ """
- @defer.inlineCallbacks
- def add_event_hashes(self, event_ids):
- hashes = yield self.get_event_reference_hashes(event_ids)
+ Args:
+ event_ids: The event IDs
+
+ Returns:
+ A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
+ """
+ hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
@@ -50,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
return list(hashes.items())
- def _get_event_reference_hashes_txn(self, txn, event_id):
+ def _get_event_reference_hashes_txn(
+ self, txn: Cursor, event_id: str
+ ) -> Dict[str, bytes]:
"""Get all the hashes for a given PDU.
Args:
- txn (cursor):
- event_id (str): Id for the Event.
+ txn:
+ event_id: Id for the Event.
Returns:
- A dict[unicode, bytes] of algorithm -> hash.
+ A mapping of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/databases/main/state.py
index 347cc50778..5c6168e301 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -16,17 +16,18 @@
import collections.abc
import logging
from collections import namedtuple
-
-from twisted.internet import defer
+from typing import Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
+from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -54,7 +55,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -93,7 +94,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We really should have an entry in the rooms table for every room we
# care about, but let's be a bit paranoid (at least while the background
# update is happening) to avoid breaking existing rooms.
- version = await self.db.simple_select_one_onecol(
+ version = await self.db_pool.simple_select_one_onecol(
table="rooms",
keyvalues={"room_id": room_id},
retcol="room_version",
@@ -108,28 +109,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_create_event_for_room(room_id)
return create_event.content.get("room_version", "1")
- @defer.inlineCallbacks
- def get_room_predecessor(self, room_id):
+ async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
- room_id (str)
+ room_id: The room ID.
Returns:
- Deferred[dict|None]: A dictionary containing the structure of the predecessor
- field from the room's create event. The structure is subject to other servers,
- but it is expected to be:
- * room_id (str): The room ID of the predecessor room
- * event_id (str): The ID of the tombstone event in the predecessor room
+ A dictionary containing the structure of the predecessor
+ field from the room's create event. The structure is subject to other servers,
+ but it is expected to be:
+ * room_id (str): The room ID of the predecessor room
+ * event_id (str): The ID of the tombstone event in the predecessor room
- None if a predecessor key is not found, or is not a dictionary.
+ None if a predecessor key is not found, or is not a dictionary.
Raises:
NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
- create_event = yield self.get_create_event_for_room(room_id)
+ create_event = await self.get_create_event_for_room(room_id)
# Retrieve the predecessor key of the create event
predecessor = create_event.content.get("predecessor", None)
@@ -140,20 +140,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return predecessor
- @defer.inlineCallbacks
- def get_create_event_for_room(self, room_id):
+ async def get_create_event_for_room(self, room_id: str) -> EventBase:
"""Get the create state event for a room.
Args:
- room_id (str)
+ room_id: The room ID.
Returns:
- Deferred[EventBase]: The room creation event.
+ The room creation event.
Raises:
NotFoundError if the room is unknown
"""
- state_ids = yield self.get_current_state_ids(room_id)
+ state_ids = await self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end
@@ -161,19 +160,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
- create_event = yield self.get_event(create_id)
+ create_event = await self.get_event(create_id)
return create_event
@cached(max_entries=100000, iterable=True)
- def get_current_state_ids(self, room_id):
+ async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
Args:
- room_id (str)
+ room_id: The room to get the state IDs of.
Returns:
- deferred: dict of (type, state_key) -> event_id
+ The current state of the room.
"""
def _get_current_state_ids_txn(txn):
@@ -186,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(
+ async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
@@ -204,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
from the database.
Returns:
- defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+ Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
- return self.get_current_state_ids(room_id)
+ return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn):
results = {}
@@ -233,22 +232,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
- @defer.inlineCallbacks
- def get_canonical_alias_for_room(self, room_id):
+ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any
Args:
- room_id (str)
+ room_id: The room ID
Returns:
- Deferred[str|None]: The canonical alias, if any
+ The canonical alias, if any
"""
- state = yield self.get_filtered_current_state_ids(
+ state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
@@ -256,15 +254,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_id:
return
- event = yield self.get_event(event_id, allow_none=True)
+ event = await self.get_event(event_id, allow_none=True)
if not event:
return
return event.content.get("canonical_alias")
@cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self.db.simple_select_one_onecol(
+ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+ return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -276,12 +274,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
- inlineCallbacks=True,
)
- def _get_state_group_for_events(self, event_ids):
+ async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
@@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {row["event_id"]: row["state_group"] for row in rows}
- @defer.inlineCallbacks
- def get_referenced_state_groups(self, state_groups):
+ async def get_referenced_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Set[int]:
"""Check if the state groups are referenced by events.
Args:
- state_groups (Iterable[int])
+ state_groups
Returns:
- Deferred[set[int]]: The subset of state groups that are
- referenced.
+ The subset of state groups that are referenced.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
@@ -322,25 +319,25 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
index_name="event_to_state_groups_sg_index",
table="event_to_state_groups",
columns=["state_group"],
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
)
@@ -353,6 +350,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn):
+ # get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ? ORDER BY room_id LIMIT ?
@@ -363,35 +361,79 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
if not room_ids:
return True, set()
+ ###########################################################################
+ #
+ # exclude rooms where we have active members
+
sql = """
SELECT room_id
- FROM current_state_events
+ FROM local_current_membership
WHERE
room_id > ? AND room_id <= ?
- AND type = 'm.room.member'
AND membership = 'join'
- AND state_key LIKE ?
GROUP BY room_id
"""
- txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
-
+ txn.execute(sql, (last_room_id, room_ids[-1]))
joined_room_ids = {row[0] for row in txn}
+ to_delete = set(room_ids) - joined_room_ids
+
+ ###########################################################################
+ #
+ # exclude rooms which we are in the process of constructing; these otherwise
+ # qualify as "rooms with no local users", and would have their
+ # forward extremities cleaned up.
+
+ # the following query will return a list of rooms which have forward
+ # extremities that are *not* also the create event in the room - ie
+ # those that are not being created currently.
+
+ sql = """
+ SELECT DISTINCT efe.room_id
+ FROM event_forward_extremities efe
+ LEFT JOIN current_state_events cse ON
+ cse.event_id = efe.event_id
+ AND cse.type = 'm.room.create'
+ AND cse.state_key = ''
+ WHERE
+ cse.event_id IS NULL
+ AND efe.room_id > ? AND efe.room_id <= ?
+ """
+
+ txn.execute(sql, (last_room_id, room_ids[-1]))
+
+ # build a set of those rooms within `to_delete` that do not appear in
+ # the above, leaving us with the rooms in `to_delete` that *are* being
+ # created.
+ creating_rooms = to_delete.difference(row[0] for row in txn)
+ logger.info("skipping rooms which are being created: %s", creating_rooms)
+
+ # now remove the rooms being created from the list of those to delete.
+ #
+ # (we could have just taken the intersection of `to_delete` with the result
+ # of the sql query, but it's useful to be able to log `creating_rooms`; and
+ # having done so, it's quicker to remove the (few) creating rooms from
+ # `to_delete` than it is to form the intersection with the (larger) list of
+ # not-creating-rooms)
+
+ to_delete -= creating_rooms
- left_rooms = set(room_ids) - joined_room_ids
+ ###########################################################################
+ #
+ # now clear the state for the rooms
- logger.info("Deleting current state left rooms: %r", left_rooms)
+ logger.info("Deleting current state left rooms: %r", to_delete)
# First we get all users that we still think were joined to the
# room. This is so that we can mark those device lists as
# potentially stale, since there may have been a period where the
# server didn't share a room with the remote user and therefore may
# have missed any device updates.
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
retcols=("state_key",),
)
@@ -399,23 +441,23 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
potentially_left_users = {row["state_key"] for row in rows}
# Now lets actually delete the rooms from the DB.
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="current_state_events",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="event_forward_extremities",
column="room_id",
- iterable=left_rooms,
+ iterable=to_delete,
keyvalues={},
)
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn,
self.DELETE_CURRENT_STATE_UPDATE_NAME,
{"last_room_id": room_ids[-1]},
@@ -423,12 +465,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
return False, potentially_left_users
- finished, potentially_left_users = await self.db.runInteraction(
+ finished, potentially_left_users = await self.db_pool.runInteraction(
"_background_remove_left_rooms", _background_remove_left_rooms_txn
)
if finished:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.DELETE_CURRENT_STATE_UPDATE_NAME
)
@@ -463,5 +505,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateStore, self).__init__(database, db_conn, hs)
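
The new filtering in _background_remove_left_rooms reduces to plain set arithmetic once the two SQL queries have run; a standalone sketch with made-up room IDs:

# Candidate rooms in this batch, rooms that still have local joined members, and
# rooms whose forward extremities include something other than the create event
# (i.e. rooms that are not mid-creation).
room_ids = ["!a:hs", "!b:hs", "!c:hs", "!d:hs"]
joined_room_ids = {"!a:hs"}
rooms_with_non_create_extremities = {"!b:hs", "!c:hs"}

to_delete = set(room_ids) - joined_room_ids
creating_rooms = to_delete.difference(rooms_with_non_create_extremities)
to_delete -= creating_rooms

assert creating_rooms == {"!d:hs"}
assert to_delete == {"!b:hs", "!c:hs"}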
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 725e12507f..356623fc6e 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -14,8 +14,7 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Tuple
from synapse.storage._base import SQLBaseStore
@@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
- def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
+ async def get_current_state_deltas(
+ self, prev_stream_id: int, max_stream_id: int
+ ) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
if it's new state.
Args:
- prev_stream_id (int): point to get changes since (exclusive)
- max_stream_id (int): the point that we know has been correctly persisted
+ prev_stream_id: point to get changes since (exclusive)
+ max_stream_id: the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
- Deferred[tuple[int, list[dict]]: A tuple consisting of:
+ A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
@@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
- return defer.succeed((max_stream_id, []))
+ return (max_stream_id, [])
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -100,22 +101,22 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
- return clipped_stream_id, self.db.cursor_to_dict(txn)
+ return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
- return self.db.simple_select_one_onecol_txn(
+ return self.db_pool.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
retcol="COALESCE(MAX(stream_id), -1)",
)
- def get_max_stream_id_in_current_state_deltas(self):
- return self.db.runInteraction(
+ async def get_max_stream_id_in_current_state_deltas(self):
+ return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
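
With the Deferred removed, callers of get_current_state_deltas simply await the (stream_id, rows) tuple. A short sketch of a caller loop under that assumption; store is any object providing the method typed above, and on_delta is a hypothetical per-row callback:

async def catch_up_state_deltas(store, prev_stream_id: int, max_stream_id: int, on_delta) -> int:
    pos = prev_stream_id
    while pos < max_stream_id:
        # The store returns the stream id these rows go up to plus the rows
        # themselves; an empty list means we are already up to date.
        pos, deltas = await store.get_current_state_deltas(pos, max_stream_id)
        if not deltas:
            break
        for delta in deltas:
            on_delta(delta)
    return pos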
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/databases/main/stats.py
index 380c1ec7da..55a250ef06 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,14 +15,15 @@
# limitations under the License.
import logging
+from collections import Counter
from itertools import chain
+from typing import Any, Dict, List, Optional, Tuple
-from twisted.internet import defer
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
-from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import cached
@@ -59,7 +60,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
class StatsStore(StateDeltasStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StatsStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -69,17 +70,20 @@ class StatsStore(StateDeltasStore):
self.stats_delta_processing_lock = DeferredLock()
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
+ "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
+ )
+ self.db_pool.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
# we no longer need to perform clean-up, but we will give ourselves
# the potential to reintroduce it in the future – so documentation
# will still encourage the use of this no-op handler.
- self.db.updates.register_noop_background_update("populate_stats_cleanup")
- self.db.updates.register_noop_background_update("populate_stats_prepare")
+ self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
+ self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
def quantise_stats_time(self, ts):
"""
@@ -97,13 +101,14 @@ class StatsStore(StateDeltasStore):
"""
return (ts // self.stats_bucket_size) * self.stats_bucket_size
- @defer.inlineCallbacks
- def _populate_stats_process_users(self, progress, batch_size):
+ async def _populate_stats_process_users(self, progress, batch_size):
"""
This is a background update which regenerates statistics for users.
"""
if not self.stats_enabled:
- yield self.db.updates._end_background_update("populate_stats_process_users")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_users"
+ )
return 1
last_user_id = progress.get("last_user_id", "")
@@ -118,35 +123,57 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn]
- users_to_work_on = yield self.db.runInteraction(
+ users_to_work_on = await self.db_pool.runInteraction(
"_populate_stats_process_users", _get_next_batch
)
# No more rooms -- complete the transaction.
if not users_to_work_on:
- yield self.db.updates._end_background_update("populate_stats_process_users")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_users"
+ )
return 1
for user_id in users_to_work_on:
- yield self._calculate_and_set_initial_state_for_user(user_id)
+ await self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_stats_process_users",
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
"populate_stats_process_users",
progress,
)
return len(users_to_work_on)
- @defer.inlineCallbacks
- def _populate_stats_process_rooms(self, progress, batch_size):
+ async def _populate_stats_process_rooms(self, progress, batch_size):
+ """
+ This was a background update which regenerated statistics for rooms.
+
+ It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
+ job was scheduled to run as part of Synapse v1.0.0, and is scheduled to run again now.
+ So that someone upgrading from <v1.0.0 does not run the potentially expensive task
+ twice, this background task has been turned into a no-op.
+
+ Further context: https://github.com/matrix-org/synapse/pull/7977
+ """
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_rooms"
+ )
+ return 1
+
+ async def _populate_stats_process_rooms_2(self, progress, batch_size):
"""
This is a background update which regenerates statistics for rooms.
+
+ It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
+ reasoning.
"""
if not self.stats_enabled:
- yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_rooms_2"
+ )
return 1
last_room_id = progress.get("last_room_id", "")
@@ -161,48 +188,68 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn]
- rooms_to_work_on = yield self.db.runInteraction(
- "populate_stats_rooms_get_batch", _get_next_batch
+ rooms_to_work_on = await self.db_pool.runInteraction(
+ "populate_stats_rooms_2_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_rooms_2"
+ )
return 1
for room_id in rooms_to_work_on:
- yield self._calculate_and_set_initial_state_for_room(room_id)
+ await self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id
- yield self.db.runInteraction(
- "_populate_stats_process_rooms",
- self.db.updates._background_update_progress_txn,
- "populate_stats_process_rooms",
+ await self.db_pool.runInteraction(
+ "_populate_stats_process_rooms_2",
+ self.db_pool.updates._background_update_progress_txn,
+ "populate_stats_process_rooms_2",
progress,
)
return len(rooms_to_work_on)
- def get_stats_positions(self):
+ async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
desc="stats_incremental_position",
)
- def update_room_state(self, room_id, fields):
- """
+ async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
+ """Update the state of a room.
+
+ fields can contain the following keys with string values:
+ * join_rules
+ * history_visibility
+ * encryption
+ * name
+ * topic
+ * avatar
+ * canonical_alias
+
+ An is_federatable key can also be included with a boolean value.
+
Args:
- room_id (str)
- fields (dict[str:Any])
+ room_id: The room ID to update the state of.
+ fields: The fields to update. This can include a partial list of the
+ above fields to only update some room information.
"""
-
- # For whatever reason some of the fields may contain null bytes, which
- # postgres isn't a fan of, so we replace those fields with null.
+ # Ensure that the values to update are valid: they should be strings and
+ # not contain any null bytes.
+ #
+ # Invalid data gets overwritten with null.
+ #
+ # Note that a missing value should not be overwritten (it keeps the
+ # previous value).
+ sentinel = object()
for col in (
"join_rules",
"history_visibility",
@@ -212,32 +259,34 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
):
- field = fields.get(col)
- if field and "\0" in field:
+ field = fields.get(col, sentinel)
+ if field is not sentinel and (not isinstance(field, str) or "\0" in field):
fields[col] = None
- return self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="room_stats_state",
keyvalues={"room_id": room_id},
values=fields,
desc="update_room_state",
)
- def get_statistics_for_subject(self, stats_type, stats_id, start, size=100):
+ async def get_statistics_for_subject(
+ self, stats_type: str, stats_id: str, start: int, size: int = 100
+ ) -> List[dict]:
"""
Get statistics for a given subject.
Args:
- stats_type (str): The type of subject
- stats_id (str): The ID of the subject (e.g. room_id or user_id)
- start (int): Pagination start. Number of entries, not timestamp.
- size (int): How many entries to return.
+ stats_type: The type of subject
+ stats_id: The ID of the subject (e.g. room_id or user_id)
+ start: Pagination start. Number of entries, not timestamp.
+ size: How many entries to return.
Returns:
- Deferred[list[dict]], where the dict has the keys of
+ A list of dicts, where each dict has the keys of
ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_statistics_for_subject",
self._get_statistics_for_subject_txn,
stats_type,
@@ -258,7 +307,7 @@ class StatsStore(StateDeltasStore):
ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
)
- slice_list = self.db.simple_select_list_paginate_txn(
+ slice_list = self.db_pool.simple_select_list_paginate_txn(
txn,
table + "_historical",
"end_ts",
@@ -272,7 +321,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
- def get_earliest_token_for_stats(self, stats_type, id):
+ async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -280,29 +329,28 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
- Deferred[int]
+ The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
allow_none=True,
)
- def bulk_update_stats_delta(self, ts, updates, stream_id):
+ async def bulk_update_stats_delta(
+ self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+ ) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
Args:
- ts (int): Current timestamp in ms
- updates(dict[str, dict[str, dict[str, Counter]]]): The updates to
- commit as a mapping stats_type -> stats_id -> field -> delta.
- stream_id (int): Current position.
-
- Returns:
- Deferred
+ ts: Current timestamp in ms
+ updates: The updates to commit as a mapping of
+ stats_type -> stats_id -> field -> delta.
+ stream_id: Current position.
"""
def _bulk_update_stats_delta_txn(txn):
@@ -320,45 +368,44 @@ class StatsStore(StateDeltasStore):
complete_with_stream_id=stream_id,
)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": stream_id},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"bulk_update_stats_delta", _bulk_update_stats_delta_txn
)
- def update_stats_delta(
+ async def update_stats_delta(
self,
- ts,
- stats_type,
- stats_id,
- fields,
- complete_with_stream_id,
- absolute_field_overrides=None,
- ):
+ ts: int,
+ stats_type: str,
+ stats_id: str,
+ fields: Dict[str, int],
+ complete_with_stream_id: Optional[int],
+ absolute_field_overrides: Optional[Dict[str, int]] = None,
+ ) -> None:
"""
Updates the statistics for a subject, with a delta (difference/relative
change).
Args:
- ts (int): timestamp of the change
- stats_type (str): "room" or "user" – the kind of subject
- stats_id (str): the subject's ID (room ID or user ID)
- fields (dict[str, int]): Deltas of stats values.
- complete_with_stream_id (int, optional):
+ ts: timestamp of the change
+ stats_type: "room" or "user" – the kind of subject
+ stats_id: the subject's ID (room ID or user ID)
+ fields: Deltas of stats values.
+ complete_with_stream_id:
If supplied, converts an incomplete row into a complete row,
with the supplied stream_id marked as the stream_id where the
row was completed.
- absolute_field_overrides (dict[str, int]): Current stats values
- (i.e. not deltas) of absolute fields.
- Does not work with per-slice fields.
+ absolute_field_overrides: Current stats values (i.e. not deltas) of
+ absolute fields. Does not work with per-slice fields.
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"update_stats_delta",
self._update_stats_delta_txn,
ts,
@@ -493,17 +540,17 @@ class StatsStore(StateDeltasStore):
else:
self.database_engine.lock_table(txn, table)
retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
- current_row = self.db.simple_select_one_txn(
+ current_row = self.db_pool.simple_select_one_txn(
txn, table, keyvalues, retcols, allow_none=True
)
if current_row is None:
merged_dict = {**keyvalues, **absolutes, **additive_relatives}
- self.db.simple_insert_txn(txn, table, merged_dict)
+ self.db_pool.simple_insert_txn(txn, table, merged_dict)
else:
for (key, val) in additive_relatives.items():
current_row[key] += val
current_row.update(absolutes)
- self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
+ self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
def _upsert_copy_from_table_with_additive_relatives_txn(
self,
@@ -590,11 +637,11 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, qargs)
else:
self.database_engine.lock_table(txn, into_table)
- src_row = self.db.simple_select_one_txn(
+ src_row = self.db_pool.simple_select_one_txn(
txn, src_table, keyvalues, copy_columns
)
all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
- dest_current_row = self.db.simple_select_one_txn(
+ dest_current_row = self.db_pool.simple_select_one_txn(
txn,
into_table,
keyvalues=all_dest_keyvalues,
@@ -610,25 +657,28 @@ class StatsStore(StateDeltasStore):
**src_row,
**additive_relatives,
}
- self.db.simple_insert_txn(txn, into_table, merged_dict)
+ self.db_pool.simple_insert_txn(txn, into_table, merged_dict)
else:
for (key, val) in additive_relatives.items():
src_row[key] = dest_current_row[key] + val
- self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+ self.db_pool.simple_update_txn(
+ txn, into_table, all_dest_keyvalues, src_row
+ )
- def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
+ async def get_changes_room_total_events_and_bytes(
+ self, min_pos: int, max_pos: int
+ ) -> Dict[str, Dict[str, int]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
- min_pos (int)
- max_pos (int)
+ min_pos
+ max_pos
Returns:
- Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field
- changes.
+ Mapping of room ID to field changes.
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"stats_incremental_total_events_and_bytes",
self.get_changes_room_total_events_and_bytes_txn,
min_pos,
@@ -696,22 +746,22 @@ class StatsStore(StateDeltasStore):
return room_deltas, user_deltas
- @defer.inlineCallbacks
- def _calculate_and_set_initial_state_for_room(self, room_id):
+ async def _calculate_and_set_initial_state_for_room(
+ self, room_id: str
+ ) -> Tuple[dict, dict, int]:
"""Calculate and insert an entry into room_stats_current.
Args:
- room_id (str)
+ room_id: The room ID under calculation.
Returns:
- Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
- counts and stream position.
+ A tuple of room state, membership counts and stream position.
"""
def _fetch_current_state_stats(txn):
pos = self.get_room_max_stream_ordering()
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="type",
@@ -767,11 +817,11 @@ class StatsStore(StateDeltasStore):
current_state_events_count,
users_in_room,
pos,
- ) = yield self.db.runInteraction(
+ ) = await self.db_pool.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats
)
- state_event_map = yield self.get_events(event_ids, get_prev_content=False)
+ state_event_map = await self.get_events(event_ids, get_prev_content=False)
room_state = {
"join_rules": None,
@@ -806,11 +856,11 @@ class StatsStore(StateDeltasStore):
event.content.get("m.federate", True) is True
)
- yield self.update_room_state(room_id, room_state)
+ await self.update_room_state(room_id, room_state)
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
- yield self.update_stats_delta(
+ await self.update_stats_delta(
ts=self.clock.time_msec(),
stats_type="room",
stats_id=room_id,
@@ -826,8 +876,7 @@ class StatsStore(StateDeltasStore):
},
)
- @defer.inlineCallbacks
- def _calculate_and_set_initial_state_for_user(self, user_id):
+ async def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user_txn(txn):
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
@@ -842,12 +891,12 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone()
return count, pos
- joined_rooms, pos = yield self.db.runInteraction(
+ joined_rooms, pos = await self.db_pool.runInteraction(
"calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn,
)
- yield self.update_stats_delta(
+ await self.update_stats_delta(
ts=self.clock.time_msec(),
stats_type="user",
stats_id=user_id,
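
The sentinel object in update_room_state distinguishes "key not supplied, keep the previous value" from "key supplied but invalid, overwrite with null", which fields.get(col) alone cannot do because a stored None is indistinguishable from an absent key. A standalone sketch of just that validation step, using the field list from the docstring above:

from typing import Any, Dict

ROOM_STATE_STRING_FIELDS = (
    "join_rules",
    "history_visibility",
    "encryption",
    "name",
    "topic",
    "avatar",
    "canonical_alias",
)

def sanitise_room_state_fields(fields: Dict[str, Any]) -> Dict[str, Any]:
    sentinel = object()  # cannot collide with any real value, unlike None
    for col in ROOM_STATE_STRING_FIELDS:
        field = fields.get(col, sentinel)
        # Present but not a valid string (or containing a null byte): null it out.
        if field is not sentinel and (not isinstance(field, str) or "\0" in field):
            fields[col] = None
    return fields

# Example: "name" is invalid and becomes None, "topic" is untouched because it
# was never supplied, "avatar" stays as given.
assert sanitise_room_state_fields({"name": "bad\0name", "avatar": "mxc://x"}) == {
    "name": None,
    "avatar": "mxc://x",
}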
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/databases/main/stream.py
index e89f0bffb5..be6df8a6d1 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,19 +39,27 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
-
-from six.moves import range
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
-from synapse.storage.engines import PostgresEngine
-from synapse.types import RoomStreamToken
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -69,8 +77,12 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine
-):
+ direction: str,
+ column_names: Tuple[str, str],
+ from_token: Optional[Tuple[int, int]],
+ to_token: Optional[Tuple[int, int]],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -91,21 +103,19 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
- direction (str): Whether we're paginating backwards("b") or
- forwards ("f").
- column_names (tuple[str, str]): The column names to bound. Must *not*
- be user defined as these get inserted directly into the SQL
- statement without escapes.
- from_token (tuple[int, int]|None): The start point for the pagination.
- This is an exclusive minimum bound if direction is "f", and an
- inclusive maximum bound if direction is "b".
- to_token (tuple[int, int]|None): The endpoint point for the pagination.
- This is an inclusive maximum bound if direction is "f", and an
- exclusive minimum bound if direction is "b".
+ direction: Whether we're paginating backwards ("b") or forwards ("f").
+ column_names: The column names to bound. Must *not* be user defined as
+ these get inserted directly into the SQL statement without escapes.
+ from_token: The start point for the pagination. This is an exclusive
+ minimum bound if direction is "f", and an inclusive maximum bound if
+ direction is "b".
+ to_token: The end point for the pagination. This is an inclusive
+ maximum bound if direction is "f", and an exclusive minimum bound if
+ direction is "b".
engine: The database engine to generate the clauses for
Returns:
- str: The sql expression
+ The SQL expression
"""
assert direction in ("b", "f")
@@ -133,7 +143,12 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause)
-def _make_generic_sql_bound(bound, column_names, values, engine):
+def _make_generic_sql_bound(
+ bound: str,
+ column_names: Tuple[str, str],
+ values: Tuple[Optional[int], int],
+ engine: BaseDatabaseEngine,
+) -> str:
"""Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@@ -143,18 +158,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
out manually.
Args:
- bound (str): The comparison operator to use. One of ">", "<", ">=",
+ bound: The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right.
- names (tuple[str, str]): The column names. Must *not* be user defined
+ column_names: The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without
escapes.
- values (tuple[int|None, int]): The values to bound the columns by. If
+ values: The values to bound the columns by. If
the first value is None then only creates a bound on the second
column.
engine: The database engine to generate the SQL for
Returns:
- str
+ The SQL statement
"""
assert bound in (">", "<", ">=", "<=")
@@ -194,7 +209,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
)
-def filter_to_clause(event_filter):
+def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -252,11 +267,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+ self._send_federation = hs.should_send_federation()
+ self._federation_shard_config = hs.config.worker.federation_shard_config
+
+ # If we're a process that sends federation we may need to reset the
+ # `federation_stream_position` table to match the current sharding
+ # config. We don't do this now as otherwise two processes could conflict
+ # during startup which would cause one to die.
+ self._need_to_reset_federation_stream_positions = self._send_federation
+
events_max = self.get_room_max_stream_ordering()
- event_cache_prefill, min_event_val = self.db.get_cache_dict(
+ event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
@@ -275,41 +300,42 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
- @defer.inlineCallbacks
- def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order="DESC"
- ):
+ async def get_room_events_stream_for_rooms(
+ self,
+ room_ids: Collection[str],
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Dict[str, Tuple[List[EventBase], str]]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_ids
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[dict[str,tuple[list[FrozenEvent], str]]]
- A map from room id to a tuple containing:
- - list of recent events in the room
- - stream ordering key for the start of the chunk of events returned.
+ A map from room id to a tuple containing:
+ - list of recent events in the room
+ - stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
- room_ids = yield self._events_stream_cache.get_entities_changed(
- room_ids, from_id
- )
+ room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
if not room_ids:
return {}
@@ -317,7 +343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
- res = yield make_deferred_yieldable(
+ res = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -337,43 +363,47 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- def get_rooms_that_changed(self, room_ids, from_key):
+ def get_rooms_that_changed(
+ self, room_ids: Collection[str], from_key: str
+ ) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
- room_ids (list)
- from_key (str): The room_key portion of a StreamToken
+ room_ids
+ from_key: The room_key portion of a StreamToken
"""
- from_key = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
- if self._events_stream_cache.has_entity_changed(room_id, from_key)
+ if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
- @defer.inlineCallbacks
- def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order="DESC"
- ):
-
+ async def get_room_events_stream_for_room(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: str,
+ limit: int = 0,
+ order: str = "DESC",
+ ) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:
- room_id (str)
- from_key (str): Token from which no events are returned before
- to_key (str): Token from which no events are returned after. (This
+ room_id
+ from_key: Token from which no events are returned before
+ to_key: Token from which no events are returned after. (This
is typically the current stream token)
- limit (int): Maximum number of events to return
- order (str): Either "DESC" or "ASC". Determines which events are
+ limit: Maximum number of events to return
+ order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the
oldest `limit` events.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
- events (in ascending order) and the token from the start of
- the chunk of events returned.
+ The list of events (in ascending order) and the token from the start
+ of the chunk of events returned.
"""
if from_key == to_key:
return [], from_key
@@ -381,9 +411,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
+ has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
if not has_changed:
return [], from_key
@@ -401,9 +429,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
+ rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -421,8 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
- @defer.inlineCallbacks
- def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ async def get_membership_changes_for_user(
+ self, user_id: str, from_key: str, to_key: str
+ ) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -451,9 +480,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
+ rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
- ret = yield self.get_events_as_list(
+ ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -461,27 +490,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
- @defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token):
+ async def get_recent_events_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[EventBase], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
- events and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of events and a token pointing to the start of the returned
+ events. The events returned are in ascending order.
"""
- rows, token = yield self.get_recent_event_ids_for_room(
+ rows, token = await self.get_recent_event_ids_for_room(
room_id, limit, end_token
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -489,20 +517,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
- @defer.inlineCallbacks
- def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ async def get_recent_event_ids_for_room(
+ self, room_id: str, limit: int, end_token: str
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Get the most recent events in the room in topological ordering.
Args:
- room_id (str)
- limit (int)
- end_token (str): The stream token representing now.
+ room_id
+ limit
+ end_token: The stream token representing now.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
- _EventDictReturn and a token pointing to the start of the returned
- events.
- The events returned are in ascending order.
+ A list of _EventDictReturn and a token pointing to the start of the
+ returned events. The events returned are in ascending order.
"""
# Allow a zero limit here, and no-op.
if limit == 0:
@@ -510,7 +537,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.db.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -523,16 +550,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
- def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+ async def get_room_event_before_stream_ordering(
+ self, room_id: str, stream_ordering: int
+ ) -> Tuple[int, int, str]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
- room_id (str):
- stream_ordering (int):
+ room_id:
+ stream_ordering:
Returns:
- Deferred[(int, int, str)]:
- (stream ordering, topological ordering, event_id)
+ A tuple of (stream ordering, topological ordering, event_id)
"""
def _f(txn):
@@ -547,76 +575,100 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.db.runInteraction("get_room_event_before_stream_ordering", _f)
+ return await self.db_pool.runInteraction(
+ "get_room_event_before_stream_ordering", _f
+ )
- @defer.inlineCallbacks
- def get_room_events_max_id(self, room_id=None):
+ async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.
By default, it returns the current global stream token. Specifying a
`room_id` causes it to return the current room specific topological
token.
"""
- token = yield self.get_room_max_stream_ordering()
+ token = self.get_room_max_stream_ordering()
if room_id is None:
return "s%d" % (token,)
else:
- topo = yield self.db.runInteraction(
+ topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
return "t%d-%d" % (topo, token)
- def get_stream_token_for_event(self, event_id):
+ async def get_stream_id_for_event(self, event_id: str) -> int:
+ """The stream ID for an event
+ Args:
+ event_id: The id of the event to look up a stream token for.
+ Raises:
+ StoreError if the event wasn't in the database.
+ Returns:
+ A stream ID.
+ """
+ return await self.db_pool.runInteraction(
+ "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+ )
+
+ def get_stream_id_for_event_txn(
+ self, txn: LoggingTransaction, event_id: str, allow_none=False,
+ ) -> int:
+ return self.db_pool.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcol="stream_ordering",
+ allow_none=allow_none,
+ )
+
+ async def get_stream_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "s%d" stream token.
+ A "s%d" stream token.
"""
- return self.db.simple_select_one_onecol(
- table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
- ).addCallback(lambda row: "s%d" % (row,))
+ stream_id = await self.get_stream_id_for_event(event_id)
+ return "s%d" % (stream_id,)
- def get_topological_token_for_event(self, event_id):
+ async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
Args:
- event_id(str): The id of the event to look up a stream token for.
+ event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A deferred "t%d-%d" topological token.
+ A "t%d-%d" topological token.
"""
- return self.db.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
- ).addCallback(
- lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
+ return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
- def get_max_topological_token(self, room_id, stream_key):
+ async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
"""Get the max topological token in a room before the given stream
ordering.
Args:
- room_id (str)
- stream_key (int)
+ room_id
+ stream_key
Returns:
- Deferred[int]
+ The maximum topological token.
"""
sql = (
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self.db.execute(
+ row = await self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key
- ).addCallback(lambda r: r[0][0] if r else 0)
+ )
+ return row[0][0] if row else 0
- def _get_max_topological_txn(self, txn, room_id):
+ def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
@@ -626,16 +678,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0
@staticmethod
- def _set_before_and_after(events, rows, topo_order=True):
+ def _set_before_and_after(
+ events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
+ ):
"""Inserts ordering information to events' internal metadata from
the DB rows.
Args:
- events (list[FrozenEvent])
- rows (list[_EventDictReturn])
- topo_order (bool): Whether the events were ordered topologically
- or by stream ordering. If true then all rows should have a non
- null topological_ordering.
+ events
+ rows
+ topo_order: Whether the events were ordered topologically or by stream
+ ordering. If true then all rows should have a non null
+ topological_ordering.
"""
for event, row in zip(events, rows):
stream = row.stream_ordering
@@ -648,25 +702,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream))
- @defer.inlineCallbacks
- def get_events_around(
- self, room_id, event_id, before_limit, after_limit, event_filter=None
- ):
+ async def get_events_around(
+ self,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter] = None,
+ ) -> dict:
"""Retrieve events and pagination tokens around a given event in a
room.
-
- Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
-
- Returns:
- dict
"""
- results = yield self.db.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -676,11 +724,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events_before = yield self.get_events_as_list(
+ events_before = await self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True
)
- events_after = yield self.get_events_as_list(
+ events_after = await self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True
)
@@ -692,29 +740,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
def _get_events_around_txn(
- self, txn, room_id, event_id, before_limit, after_limit, event_filter
- ):
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ event_id: str,
+ before_limit: int,
+ after_limit: int,
+ event_filter: Optional[Filter],
+ ) -> dict:
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
- room_id (str)
- event_id (str)
- before_limit (int)
- after_limit (int)
- event_filter (Filter|None)
+ room_id
+ event_id
+ before_limit
+ after_limit
+ event_filter
Returns:
dict
"""
- results = self.db.simple_select_one_txn(
+ results = self.db_pool.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
)
+ # This cannot happen as `allow_none=False`.
+ assert results is not None
+
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
@@ -750,22 +807,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- @defer.inlineCallbacks
- def get_all_new_events_stream(self, from_id, current_id, limit):
+ async def get_all_new_events_stream(
+ self, from_id: int, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Args:
- from_id (int): the stream_ordering of the last event we processed
- current_id (int): the stream_ordering of the most recently processed event
- limit (int): the maximum number of events to return
+ from_id: the stream_ordering of the last event we processed
+ current_id: the stream_ordering of the most recently processed event
+ limit: the maximum number of events to return
Returns:
- Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
- `next_id` is the next value to pass as `from_id` (it will either be the
- stream_ordering of the last returned event, or, if fewer than `limit` events
- were found, `current_id`.
+ A tuple of (next_id, events), where `next_id` is the next value to
+ pass as `from_id` (it will either be the stream_ordering of the
+ last returned event, or, if fewer than `limit` events were found,
+ the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
@@ -787,63 +845,134 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db.runInteraction(
+ upper_bound, event_ids = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return upper_bound, events
- def get_federation_out_pos(self, typ):
- return self.db.simple_select_one_onecol(
+ async def get_federation_out_pos(self, typ: str) -> int:
+ if self._need_to_reset_federation_stream_positions:
+ await self.db_pool.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ return await self.db_pool.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
desc="get_federation_out_pos",
)
- def update_federation_out_pos(self, typ, stream_id):
- return self.db.simple_update_one(
+ async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
+ if self._need_to_reset_federation_stream_positions:
+ await self.db_pool.runInteraction(
+ "_reset_federation_positions_txn", self._reset_federation_positions_txn
+ )
+ self._need_to_reset_federation_stream_positions = False
+
+ await self.db_pool.simple_update_one(
table="federation_stream_position",
- keyvalues={"type": typ},
+ keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos",
)
- def has_room_changed_since(self, room_id, stream_id):
+ def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
+ """Fiddles with the `federation_stream_position` table to make it match
+ the configured federation sender instances during start up.
+ """
+
+ # The federation sender instances may have changed, so we need to
+ # massage the `federation_stream_position` table to have a row per type
+ # per instance sending federation. If there is a mismatch we update the
+ # table with the correct rows using the *minimum* stream ID seen. This
+ # may result in resending of events/EDUs to remote servers, but that is
+ # preferable to dropping them.
+
+ if not self._send_federation:
+ return
+
+ # Pull out the configured instances. If we don't have a shard config then
+ # we assume that we're the only instance sending.
+ configured_instances = self._federation_shard_config.instances
+ if not configured_instances:
+ configured_instances = [self._instance_name]
+ elif self._instance_name not in configured_instances:
+ return
+
+ instances_in_table = self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={},
+ retcol="instance_name",
+ )
+
+ if set(instances_in_table) == set(configured_instances):
+ # Nothing to do
+ return
+
+ sql = """
+ SELECT type, MIN(stream_id) FROM federation_stream_position
+ GROUP BY type
+ """
+ txn.execute(sql)
+ min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
+
+ # Ensure we do actually have some values here
+ assert set(min_positions) == {"federation", "events"}
+
+ sql = """
+ DELETE FROM federation_stream_position
+ WHERE NOT (%s)
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "instance_name", configured_instances
+ )
+ txn.execute(sql % (clause,), args)
+
+ for typ, stream_id in min_positions.items():
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="federation_stream_position",
+ keyvalues={"type": typ, "instance_name": self._instance_name},
+ values={"stream_id": stream_id},
+ )
+
+ def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(
self,
- txn,
- room_id,
- from_token,
- to_token=None,
- direction="b",
- limit=-1,
- event_filter=None,
- ):
+ txn: LoggingTransaction,
+ room_id: str,
+ from_token: RoomStreamToken,
+ to_token: Optional[RoomStreamToken] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[_EventDictReturn], str]:
"""Returns list of events before or after a given token.
Args:
txn
- room_id (str)
- from_token (RoomStreamToken): The token used to stream from
- to_token (RoomStreamToken|None): A token which if given limits the
- results to only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
+ room_id
+ from_token: The token used to stream from
+ to_token: A token which if given limits the results to only those before
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided filters the events to
those that match the filter.
Returns:
- Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
- as a list of _EventDictReturn and a token that points to the end
- of the result set. If no events are returned then the end of the
- stream has been reached (i.e. there are no events between
- `from_token` and `to_token`), or `limit` is zero.
+ A list of _EventDictReturn and a token that points to the end of the
+ result set. If no events are returned then the end of the stream has
+ been reached (i.e. there are no events between `from_token` and
+ `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -927,35 +1056,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token)
- @defer.inlineCallbacks
- def paginate_room_events(
- self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
- ):
+ async def paginate_room_events(
+ self,
+ room_id: str,
+ from_key: str,
+ to_key: Optional[str] = None,
+ direction: str = "b",
+ limit: int = -1,
+ event_filter: Optional[Filter] = None,
+ ) -> Tuple[List[EventBase], str]:
"""Returns list of events before or after a given token.
Args:
- room_id (str)
- from_key (str): The token used to stream from
- to_key (str|None): A token which if given limits the results to
- only those before
- direction(char): Either 'b' or 'f' to indicate whether we are
- paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return.
- event_filter (Filter|None): If provided filters the events to
- those that match the filter.
+ room_id
+ from_key: The token used to stream from
+ to_key: A token which if given limits the results to only those before
+ direction: Either 'b' or 'f' to indicate whether we are paginating
+ forwards or backwards from `from_key`.
+ limit: The maximum number of events to return.
+ event_filter: If provided filters the events to those that match the filter.
Returns:
- tuple[list[FrozenEvent], str]: Returns the results as a list of
- events and a token that points to the end of the result set. If no
- events are returned then the end of the stream has been reached
- (i.e. there are no events between `from_key` and `to_key`).
+ The results as a list of events and a token that points to the end
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between `from_key`
+ and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.db.runInteraction(
+ rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -966,7 +1098,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
@@ -976,8 +1108,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self):
+ def get_room_max_stream_ordering(self) -> int:
return self._stream_id_gen.get_current_token()
- def get_room_min_stream_ordering(self):
+ def get_room_min_stream_ordering(self) -> int:
return self._backfill_id_gen.get_current_token()
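
The reconciliation in _reset_federation_positions_txn boils down to: if the instance names in federation_stream_position no longer match the configured senders, rewrite the table using the minimum stream ID seen per type, trading some resending for never dropping anything. A table-free sketch of that core calculation; the function name and the (type, instance_name, stream_id) tuple layout of rows are illustrative, and configured_instances is assumed to already contain the single-instance fallback:

from typing import Dict, List, Set, Tuple

def reconcile_federation_positions(
    rows: List[Tuple[str, str, int]],
    configured_instances: Set[str],
    my_instance: str,
) -> Dict[str, int]:
    """Return the per-type stream positions this instance should (re)write,
    or an empty dict if nothing needs to change."""
    if my_instance not in configured_instances:
        # Not a configured sender; leave the table alone.
        return {}

    instances_in_table = {instance for _, instance, _ in rows}
    if instances_in_table == configured_instances:
        # The table already matches the sharding config.
        return {}

    # Use the *minimum* position per type: we may resend some events/EDUs,
    # but we will never skip any.
    min_positions = {}  # type: Dict[str, int]
    for typ, _, stream_id in rows:
        min_positions[typ] = min(stream_id, min_positions.get(typ, stream_id))
    return min_positions

# Example: two old sender instances collapse into one; the lower of the two
# positions wins for each stream type.
rows = [
    ("federation", "sender1", 10),
    ("federation", "sender2", 7),
    ("events", "sender1", 20),
    ("events", "sender2", 25),
]
assert reconcile_federation_positions(rows, {"master"}, "master") == {
    "federation": 7,
    "events": 20,
}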
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/databases/main/tags.py
index 4219018302..96ffe26cc9 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,14 +15,12 @@
# limitations under the License.
import logging
+from typing import Dict, List, Tuple
-from six.moves import range
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
+from synapse.storage._base import db_to_json
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -30,43 +28,53 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
- def get_tags_for_user(self, user_id):
+ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for a user.
Args:
- user_id(str): The user to get the tags for.
+ user_id: The user to get the tags for.
Returns:
- A deferred dict mapping from room_id strings to dicts mapping from
- tag strings to tag content.
+ A mapping from room_id strings to dicts mapping from tag strings to
+ tag content.
"""
- deferred = self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
- @deferred.addCallback
- def tags_by_room(rows):
- tags_by_room = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = json.loads(row["content"])
- return tags_by_room
+ tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
+ for row in rows:
+ room_tags = tags_by_room.setdefault(row["room_id"], {})
+ room_tags[row["tag"]] = db_to_json(row["content"])
+ return tags_by_room
- return deferred
+ async def get_all_updated_tags(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+ """Get updates for tags replication stream.
- @defer.inlineCallbacks
- def get_all_updated_tags(self, last_id, current_id, limit):
- """Get all the client tags that have changed on the server
Args:
- last_id(int): The position to fetch from.
- current_id(int): The position to fetch up to.
+ instance_name: The writer we want to fetch updates from. Unused
+ here since there is only ever one writer.
+ last_id: The token to fetch updates from. Exclusive.
+ current_id: The token to fetch updates up to. Inclusive.
+ limit: The requested limit for the number of rows to return. The
+ function may return more or fewer rows.
+
Returns:
- A deferred list of tuples of stream_id int, user_id string,
- room_id string, tag string and content string.
+ A tuple consisting of: the updates, a token to use to fetch
+ subsequent updates, and whether we returned fewer rows than exists
+ between the requested tokens due to the limit.
+
+ The token returned can be used in a subsequent call to this
+ function to get further updates.
+
+ The updates are a list of 2-tuples of stream ID and the row data
"""
+
if last_id == current_id:
- return []
+ return [], current_id, False
def get_all_updated_tags_txn(txn):
sql = (
@@ -78,7 +86,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- tag_ids = yield self.db.runInteraction(
+ tag_ids = await self.db_pool.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@@ -89,35 +97,43 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
- tags.append(json.dumps(tag) + ":" + content)
+ tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
- results.append((stream_id, user_id, room_id, tag_json))
+ results.append((stream_id, (user_id, room_id, tag_json)))
return results
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
- tags = yield self.db.runInteraction(
+ tags = await self.db_pool.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
- return results
+ limited = False
+ upto_token = current_id
+ if len(results) >= limit:
+ upto_token = results[-1][0]
+ limited = True
+
+ return results, upto_token, limited
- @defer.inlineCallbacks
- def get_updated_tags(self, user_id, stream_id):
+ async def get_updated_tags(
+ self, user_id: str, stream_id: int
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
+
Returns:
- A deferred dict mapping from room_id strings to lists of tag
- strings for all the rooms that changed since the stream_id token.
+ A mapping from room_id strings to lists of tag strings for all the
+ rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
@@ -135,52 +151,58 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed:
return {}
- room_ids = yield self.db.runInteraction(
+ room_ids = await self.db_pool.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
- tags_by_room = yield self.get_tags_for_user(user_id)
+ tags_by_room = await self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
return results
- def get_tags_for_room(self, user_id, room_id):
+ async def get_tags_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the tags for the given room
+
Args:
- user_id(str): The user to get tags for
- room_id(str): The room to get tags for
+ user_id: The user to get tags for
+ room_id: The room to get tags for
+
Returns:
- A deferred list of string tags.
+ A mapping of tags to tag content.
"""
- return self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
- ).addCallback(
- lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
)
+ return {row["tag"]: db_to_json(row["content"]) for row in rows}
class TagsStore(TagsWorkerStore):
- @defer.inlineCallbacks
- def add_tag_to_room(self, user_id, room_id, tag, content):
+ async def add_tag_to_room(
+ self, user_id: str, room_id: str, tag: str, content: JsonDict
+ ) -> int:
"""Add a tag to a room for a user.
+
Args:
- user_id(str): The user to add a tag for.
- room_id(str): The room to add a tag for.
- tag(str): The tag name to add.
- content(dict): A json object to associate with the tag.
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ tag: The tag name to add.
+ content: A json object to associate with the tag.
+
Returns:
- A deferred that completes once the tag has been added.
+ The next account data ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -188,19 +210,18 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
- yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
+ with await self._account_data_id_gen.get_next() as next_id:
+ await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- @defer.inlineCallbacks
- def remove_tag_from_room(self, user_id, room_id, tag):
+ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.
+
Returns:
- A deferred that completes once the tag has been removed
+ The next account data ID.
"""
def remove_tag_txn(txn, next_id):
@@ -211,22 +232,23 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with self._account_data_id_gen.get_next() as next_id:
- yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
+ with await self._account_data_id_gen.get_next() as next_id:
+ await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
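The `with await self._account_data_id_gen.get_next() as next_id` pattern used in both methods above assumes get_next() is a coroutine that resolves to a context manager scoping the reserved ID. A toy illustration of that interface shape only; this is not Synapse's StreamIdGenerator:

import contextlib

class ToyIdGenerator:
    """Illustrative only: awaiting get_next() reserves an ID, and the context
    manager scopes the work done while that ID is considered in flight."""

    def __init__(self, start: int = 0):
        self._current = start

    async def get_next(self):
        self._current += 1
        next_id = self._current

        @contextlib.contextmanager
        def _manager():
            yield next_id

        return _manager()

    def get_current_token(self) -> int:
        return self._current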
- def _update_revision_txn(self, txn, user_id, room_id, next_id):
+ def _update_revision_txn(
+ self, txn, user_id: str, room_id: str, next_id: int
+ ) -> None:
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
- user_id(str): The ID of the user.
- room_id(str): The ID of the room.
- next_id(int): The the revision to advance to.
+ user_id: The ID of the user.
+ room_id: The ID of the room.
+ next_id: The revision to advance to.
"""
txn.call_after(
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/databases/main/transactions.py
index a9bf457939..5b31aab700 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,14 +15,14 @@
import logging
from collections import namedtuple
+from typing import Optional, Tuple
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
db_binary_type = memoryview
@@ -46,7 +46,7 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(TransactionStore, self).__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
@@ -57,21 +57,23 @@ class TransactionStore(SQLBaseStore):
expiry_ms=5 * 60 * 1000,
)
- def get_received_txn_response(self, transaction_id, origin):
+ async def get_received_txn_response(
+ self, transaction_id: str, origin: str
+ ) -> Optional[Tuple[int, JsonDict]]:
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
body (as a dict).
Args:
- transaction_id (str)
- origin(str)
+ transaction_id
+ origin
Returns:
- tuple: None if we have not previously responded to
- this transaction or a 2-tuple of (int, dict)
+ None if we have not previously responded to this transaction or a
+ 2-tuple of (int, dict)
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@@ -79,7 +81,7 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -100,20 +102,21 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_received_txn_response(self, transaction_id, origin, code, response_dict):
- """Persist the response we returened for an incoming transaction, and
+ async def set_received_txn_response(
+ self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
+ ) -> None:
+ """Persist the response we returned for an incoming transaction, and
should return for subsequent transactions with the same transaction_id
and origin.
Args:
- txn
- transaction_id (str)
- origin (str)
- code (int)
- response_json (str)
+ transaction_id: The incoming transaction ID.
+ origin: The origin server.
+ code: The response code.
+ response_dict: The response, to be encoded into JSON.
"""
- return self.db.simple_insert(
+ await self.db_pool.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@@ -126,8 +129,7 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response",
)
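Taken together, get_received_txn_response and set_received_txn_response let the federation layer answer a retried transaction without re-processing it. A hedged sketch of the intended call pattern; handle_pdus and the handler wiring are placeholders, only the two store methods come from this diff:

async def on_incoming_transaction(store, origin: str, transaction_id: str, pdus):
    cached = await store.get_received_txn_response(transaction_id, origin)
    if cached is not None:
        # Already answered this (transaction_id, origin) pair: replay the response.
        return cached

    code, response = await handle_pdus(pdus)  # placeholder for real processing
    await store.set_received_txn_response(transaction_id, origin, code, response)
    return code, response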
- @defer.inlineCallbacks
- def get_destination_retry_timings(self, destination):
+ async def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
Args:
@@ -142,7 +144,7 @@ class TransactionStore(SQLBaseStore):
if result is not SENTINEL:
return result
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
@@ -154,7 +156,7 @@ class TransactionStore(SQLBaseStore):
return result
def _get_destination_retry_timings(self, txn, destination):
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -167,21 +169,25 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_destination_retry_timings(
- self, destination, failure_ts, retry_last_ts, retry_interval
- ):
+ async def set_destination_retry_timings(
+ self,
+ destination: str,
+ failure_ts: Optional[int],
+ retry_last_ts: int,
+ retry_interval: int,
+ ) -> None:
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occurring.
Args:
- destination (str)
- failure_ts (int|None) - when the server started failing (ms since epoch)
- retry_last_ts (int) - time of last retry attempt in unix epoch ms
- retry_interval (int) - how long until next retry in ms
+ destination
+ failure_ts: when the server started failing (ms since epoch)
+ retry_last_ts: time of last retry attempt in unix epoch ms
+ retry_interval: how long until next retry in ms
"""
self._destination_retry_cache.pop(destination, None)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
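For context on the three values persisted here: retry_last_ts and retry_interval drive the per-destination backoff, while failure_ts records when the outage began. A minimal sketch of how such fields are typically interpreted; this is an illustration, not Synapse's retry limiter:

from typing import Optional

def can_retry(now_ms: int, retry_last_ts: int, retry_interval: int,
              failure_ts: Optional[int]) -> bool:
    # retry_interval == 0 means retrying is no longer occurring, i.e. the
    # destination is considered healthy again.
    if retry_interval == 0:
        return True
    # failure_ts (when the server started failing) is not needed for this
    # check; it is more useful for deciding when to give up on a destination.
    return now_ms >= retry_last_ts + retry_interval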
@@ -221,7 +227,7 @@ class TransactionStore(SQLBaseStore):
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
- prev_row = self.db.simple_select_one_txn(
+ prev_row = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -230,7 +236,7 @@ class TransactionStore(SQLBaseStore):
)
if not prev_row:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="destinations",
values={
@@ -241,7 +247,7 @@ class TransactionStore(SQLBaseStore):
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"destinations",
keyvalues={"destination": destination},
@@ -257,13 +263,13 @@ class TransactionStore(SQLBaseStore):
"cleanup_transactions", self._cleanup_transactions
)
- def _cleanup_transactions(self):
+ async def _cleanup_transactions(self) -> None:
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 1d8ee22fb1..b89668d561 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import attr
-import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
+from synapse.util import json_encoder, stringutils
@attr.s
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -81,7 +81,7 @@ class UIAuthWorkerStore(SQLBaseStore):
session_id = stringutils.random_string(24)
try:
- await self.db.simple_insert(
+ await self.db_pool.simple_insert(
table="ui_auth_sessions",
values={
"session_id": session_id,
@@ -97,7 +97,7 @@ class UIAuthWorkerStore(SQLBaseStore):
return UIAuthSessionData(
session_id, clientdict, uri, method, description
)
- except self.db.engine.module.IntegrityError:
+ except self.db_pool.engine.module.IntegrityError:
attempts += 1
raise StoreError(500, "Couldn't generate a session ID.")
@@ -111,14 +111,14 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session is not found.
"""
- result = await self.db.simple_select_one(
+ result = await self.db_pool.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("clientdict", "uri", "method", "description"),
desc="get_ui_auth_session",
)
- result["clientdict"] = json.loads(result["clientdict"])
+ result["clientdict"] = db_to_json(result["clientdict"])
return UIAuthSessionData(session_id, **result)
@@ -140,13 +140,13 @@ class UIAuthWorkerStore(SQLBaseStore):
# Note that we need to allow for the same stage to complete multiple
# times here so that registration is idempotent.
try:
- await self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
- values={"result": json.dumps(result)},
+ values={"result": json_encoder.encode(result)},
desc="mark_ui_auth_stage_complete",
)
- except self.db.engine.module.IntegrityError:
+ except self.db_pool.engine.module.IntegrityError:
raise StoreError(400, "Unknown session ID: %s" % (session_id,))
async def get_completed_ui_auth_stages(
@@ -162,13 +162,13 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db.simple_select_list(
+ for row in await self.db_pool.simple_select_list(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id},
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
):
- results[row["stage_type"]] = json.loads(row["result"])
+ results[row["stage_type"]] = db_to_json(row["result"])
return results
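The repeated json.dumps/json.loads to json_encoder.encode/db_to_json substitutions in this file (and the files above) assume a pair of shared helpers. A rough sketch of what they are expected to do, inferred from their usage in this diff rather than their exact definitions: a single compact encoder, and a decoder tolerant of the bytes/memoryview values Postgres can return:

import json

# Assumed shape of synapse.util.json_encoder: one shared, compact encoder.
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))

def db_to_json(db_content):
    # Assumed shape of synapse.storage._base.db_to_json: normalise whatever the
    # database driver hands back into text, then decode it as JSON.
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()
    if isinstance(db_content, bytes):
        db_content = db_content.decode("utf8")
    return json.loads(db_content)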
@@ -184,9 +184,9 @@ class UIAuthWorkerStore(SQLBaseStore):
The dictionary from the client root level, not the 'auth' key.
"""
# The clientdict gets stored as JSON.
- clientdict_json = json.dumps(clientdict)
+ clientdict_json = json_encoder.encode(clientdict)
- self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
updatevalues={"clientdict": clientdict_json},
@@ -206,7 +206,7 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_ui_auth_session_data",
self._set_ui_auth_session_data_txn,
session_id,
@@ -214,24 +214,26 @@ class UIAuthWorkerStore(SQLBaseStore):
value,
)
- def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ def _set_ui_auth_session_data_txn(
+ self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+ ):
# Get the current value.
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
- )
+ ) # type: Dict[str, Any] # type: ignore
# Update it and add it back to the database.
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
serverdict[key] = value
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
- updatevalues={"serverdict": json.dumps(serverdict)},
+ updatevalues={"serverdict": json_encoder.encode(serverdict)},
)
async def get_ui_auth_session_data(
@@ -247,20 +249,48 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
- result = await self.db.simple_select_one(
+ result = await self.db_pool.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
desc="get_ui_auth_session_data",
)
- serverdict = json.loads(result["serverdict"])
+ serverdict = db_to_json(result["serverdict"])
return serverdict.get(key, default)
+ async def add_user_agent_ip_to_ui_auth_session(
+ self, session_id: str, user_agent: str, ip: str,
+ ):
+ """Add the given user agent / IP to the tracking table
+ """
+ await self.db_pool.simple_upsert(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+ values={},
+ desc="add_user_agent_ip_to_ui_auth_session",
+ )
+
+ async def get_user_agents_ips_to_ui_auth_session(
+ self, session_id: str,
+ ) -> List[Tuple[str, str]]:
+ """Get the given user agents / IPs used during the ui auth process
+
+ Returns:
+ List of user_agent/ip pairs
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="ui_auth_sessions_ips",
+ keyvalues={"session_id": session_id},
+ retcols=("user_agent", "ip"),
+ desc="get_user_agents_ips_to_ui_auth_session",
+ )
+ return [(row["user_agent"], row["ip"]) for row in rows]
+
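The two new methods above back a small tracking feature: every UI-auth request records the requester's user agent and IP against the session, de-duplicated by the composite key in the upsert. A hedged usage sketch; the wrapper functions are hypothetical, only the store calls are from this diff:

async def record_ui_auth_request(store, session_id: str, user_agent: str, ip: str):
    await store.add_user_agent_ip_to_ui_auth_session(session_id, user_agent, ip)

async def list_ui_auth_requesters(store, session_id: str):
    # Returns the de-duplicated (user_agent, ip) pairs seen for this session.
    return await store.get_user_agents_ips_to_ui_auth_session(session_id)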
class UIAuthStore(UIAuthWorkerStore):
- def delete_old_ui_auth_sessions(self, expiration_time: int):
+ async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -269,20 +299,31 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
)
- def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ def _delete_old_ui_auth_sessions_txn(
+ self, txn: LoggingTransaction, expiration_time: int
+ ):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
+ # Delete the corresponding IP/user agents.
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_ips",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
# Delete the corresponding completed credentials.
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
@@ -291,7 +332,7 @@ class UIAuthStore(UIAuthWorkerStore):
)
# Finally, delete the sessions.
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="ui_auth_sessions",
column="session_id",
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 6b8130bf0f..f2f9a5799a 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,13 +15,12 @@
import logging
import re
-
-from twisted.internet import defer
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, JoinRules
-from synapse.storage.data_stores.main.state import StateFilter
-from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.state import StateFilter
+from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -38,29 +37,28 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables",
self._populate_user_directory_createtables,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_process_rooms",
self._populate_user_directory_process_rooms,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_process_users",
self._populate_user_directory_process_users,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
- @defer.inlineCallbacks
- def _populate_user_directory_createtables(self, progress, batch_size):
+ async def _populate_user_directory_createtables(self, progress, batch_size):
# Get all the rooms that we want to process.
def _make_staging_area(txn):
@@ -85,7 +83,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
# If search all users is on, get all the users we want to add.
@@ -100,43 +98,45 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("SELECT name FROM users")
users = [{"user_id": x[0]} for x in txn.fetchall()]
- self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
- new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.db.runInteraction(
+ new_pos = await self.get_max_stream_id_in_current_state_deltas()
+ await self.db_pool.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
- yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ await self.db_pool.simple_insert(
+ TEMP_TABLE + "_position", {"position": new_pos}
+ )
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_createtables"
)
return 1
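All of the _populate_user_directory_* handlers in this file follow the same background-update contract once converted to async: the updater awaits the handler with (progress, batch_size), the handler does a bounded chunk of work, records its progress, and returns how much it processed, calling _end_background_update when there is nothing left. A generic sketch of that contract; fetch_batch and the update name are placeholders:

async def example_background_update(self, progress, batch_size):
    last_seen = progress.get("last_seen", 0)

    def _fetch(txn):
        return fetch_batch(txn, last_seen, batch_size)  # placeholder query

    rows = await self.db_pool.runInteraction("example_bg_update_fetch", _fetch)

    if not rows:
        await self.db_pool.updates._end_background_update("example_background_update")
        return 1

    await self.db_pool.runInteraction(
        "example_bg_update_progress",
        self.db_pool.updates._background_update_progress_txn,
        "example_background_update",
        {"last_seen": rows[-1]},
    )
    return len(rows)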
- @defer.inlineCallbacks
- def _populate_user_directory_cleanup(self, progress, batch_size):
+ async def _populate_user_directory_cleanup(self, progress, batch_size):
"""
Update the user directory stream position, then clean up the old tables.
"""
- position = yield self.db.simple_select_one_onecol(
+ position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
- yield self.update_user_directory_stream_pos(position)
+ await self.update_user_directory_stream_pos(position)
def _delete_staging_area(txn):
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area
)
- yield self.db.updates._end_background_update("populate_user_directory_cleanup")
+ await self.db_pool.updates._end_background_update(
+ "populate_user_directory_cleanup"
+ )
return 1
- @defer.inlineCallbacks
- def _populate_user_directory_process_rooms(self, progress, batch_size):
+ async def _populate_user_directory_process_rooms(self, progress, batch_size):
"""
Args:
progress (dict)
@@ -147,7 +147,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# If we don't have progress filled in, delete everything.
if not progress:
- yield self.delete_all_from_user_dir()
+ await self.delete_all_from_user_dir()
def _get_next_batch(txn):
# Only fetch 250 rooms, so we don't fetch too many at once, even
@@ -172,13 +172,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return rooms_to_work_on
- rooms_to_work_on = yield self.db.runInteraction(
+ rooms_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_rooms"
)
return 1
@@ -191,19 +191,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
- is_in_room = yield self.is_host_joined(room_id, self.server_name)
+ is_in_room = await self.is_host_joined(room_id, self.server_name)
if is_in_room:
- is_public = yield self.is_room_world_readable_or_publicly_joinable(
+ is_public = await self.is_room_world_readable_or_publicly_joinable(
room_id
)
- users_with_profile = yield state.get_current_users_in_room(room_id)
+ users_with_profile = await state.get_current_users_in_room(room_id)
user_ids = set(users_with_profile)
# Update each user in the user directory.
for user_id, profile in users_with_profile.items():
- yield self.update_profile_in_user_dir(
+ await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -217,7 +217,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
to_insert.add(user_id)
if to_insert:
- yield self.add_users_in_public_rooms(room_id, to_insert)
+ await self.add_users_in_public_rooms(room_id, to_insert)
to_insert.clear()
else:
for user_id in user_ids:
@@ -237,22 +237,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# If it gets too big, stop and write to the database
# to prevent storing too much in RAM.
if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
- yield self.add_users_who_share_private_room(
+ await self.add_users_who_share_private_room(
room_id, to_insert
)
to_insert.clear()
if to_insert:
- yield self.add_users_who_share_private_room(room_id, to_insert)
+ await self.add_users_who_share_private_room(room_id, to_insert)
to_insert.clear()
# We've finished a room. Delete it from the table.
- yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ await self.db_pool.simple_delete_one(
+ TEMP_TABLE + "_rooms", {"room_id": room_id}
+ )
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory",
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_rooms",
progress,
)
@@ -265,13 +267,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count
- @defer.inlineCallbacks
- def _populate_user_directory_process_users(self, progress, batch_size):
+ async def _populate_user_directory_process_users(self, progress, batch_size):
"""
If search_all_users is enabled, add all of the users to the user directory.
"""
if not self.hs.config.user_directory_search_all_users:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
@@ -297,13 +298,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return users_to_work_on
- users_to_work_on = yield self.db.runInteraction(
+ users_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more users -- complete the transaction.
if not users_to_work_on:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
@@ -314,26 +315,27 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
for user_id in users_to_work_on:
- profile = yield self.get_profileinfo(get_localpart_from_id(user_id))
- yield self.update_profile_in_user_dir(
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
# We've finished processing a user. Delete it from the table.
- yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+ await self.db_pool.simple_delete_one(
+ TEMP_TABLE + "_users", {"user_id": user_id}
+ )
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory",
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_users",
progress,
)
return len(users_to_work_on)
- @defer.inlineCallbacks
- def is_room_world_readable_or_publicly_joinable(self, room_id):
+ async def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable
"""
@@ -343,33 +345,40 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = yield self.get_filtered_current_state_ids(
+ current_state_ids = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types(types_to_filter)
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
- join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+ join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
- hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+ hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
return True
return False
- def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
+ async def update_profile_in_user_dir(
+ self, user_id: str, display_name: str, avatar_url: str
+ ) -> None:
"""
Update or add a user's profile in the user directory.
"""
+ # If the display name or avatar URL is an unexpected type, replace it with None.
+ if not isinstance(display_name, str):
+ display_name = None
+ if not isinstance(avatar_url, str):
+ avatar_url = None
def _update_profile_in_user_dir_txn(txn):
- new_entry = self.db.simple_upsert_txn(
+ new_entry = self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@@ -443,7 +452,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
@@ -456,21 +465,23 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def add_users_who_share_private_room(self, room_id, user_id_tuples):
+ async def add_users_who_share_private_room(
+ self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+ room_id
+ user_id_tuples: iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
- self.db.simple_upsert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
@@ -482,22 +493,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
- def add_users_in_public_rooms(self, room_id, user_ids):
+ async def add_users_in_public_rooms(
+ self, room_id: str, user_ids: Iterable[str]
+ ) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
- room_id (str)
- user_ids (list[str])
+ room_id
+ user_ids
"""
def _add_users_in_public_rooms_txn(txn):
- self.db.simple_upsert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
@@ -506,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
- def delete_all_from_user_dir(self):
+ async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
@@ -521,13 +534,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
- def get_user_in_directory(self, user_id):
- return self.db.simple_select_one(
+ async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+ return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -535,8 +548,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
- def update_user_directory_stream_pos(self, stream_id):
- return self.db.simple_update_one(
+ async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+ await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
@@ -550,47 +563,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
- def remove_from_user_dir(self, user_id):
+ async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+ await self.db_pool.runInteraction(
+ "remove_from_user_dir", _remove_from_user_dir_txn
+ )
- @defer.inlineCallbacks
- def get_users_in_dir_due_to_room(self, room_id):
+ async def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
- user_ids_share_pub = yield self.db.simple_select_onecol(
+ user_ids_share_pub = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids_share_priv = yield self.db.simple_select_onecol(
+ user_ids_share_priv = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
@@ -602,39 +616,38 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return user_ids
- def remove_user_who_share_room(self, user_id, room_id):
+ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
"""
Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
- user_id (str)
- room_id (str)
+ user_id
+ room_id
"""
def _remove_user_who_share_room_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id, "room_id": room_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_in_public_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- @defer.inlineCallbacks
- def get_user_dir_rooms_user_is_in(self, user_id):
+ async def get_user_dir_rooms_user_is_in(self, user_id):
"""
Returns the rooms that a user is in.
@@ -644,14 +657,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns:
list: user_id
"""
- rows = yield self.db.simple_select_onecol(
+ rows = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
- pub_rows = yield self.db.simple_select_onecol(
+ pub_rows = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@@ -662,42 +675,57 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- @defer.inlineCallbacks
- def get_rooms_in_common_for_users(self, user_id, other_user_id):
- """Given two user_ids find out the list of rooms they share.
+ @cached()
+ async def get_shared_rooms_for_users(
+ self, user_id: str, other_user_id: str
+ ) -> Set[str]:
"""
- sql = """
- SELECT room_id FROM (
- SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (event_id)
- WHERE type = 'm.room.member'
- AND m.membership = 'join'
- AND state_key = ?
- ) AS f1 INNER JOIN (
- SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (event_id)
- WHERE type = 'm.room.member'
- AND m.membership = 'join'
- AND state_key = ?
- ) f2 USING (room_id)
+ Returns the rooms that a local user shares with another local or remote user.
+
+ Args:
+ user_id: The MXID of a local user
+ other_user_id: The MXID of the other user
+
+ Returns:
+ A set of room IDs that the users share.
"""
- rows = yield self.db.execute(
- "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
+ def _get_shared_rooms_for_users_txn(txn):
+ txn.execute(
+ """
+ SELECT p1.room_id
+ FROM users_in_public_rooms as p1
+ INNER JOIN users_in_public_rooms as p2
+ ON p1.room_id = p2.room_id
+ AND p1.user_id = ?
+ AND p2.user_id = ?
+ UNION
+ SELECT room_id
+ FROM users_who_share_private_rooms
+ WHERE
+ user_id = ?
+ AND other_user_id = ?
+ """,
+ (user_id, other_user_id, user_id, other_user_id),
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows
+
+ rows = await self.db_pool.runInteraction(
+ "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
)
- return [room_id for room_id, in rows]
+ return {row["room_id"] for row in rows}
- def get_user_directory_stream_pos(self):
- return self.db.simple_select_one_onecol(
+ async def get_user_directory_stream_pos(self) -> int:
+ return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
desc="get_user_directory_stream_pos",
)
- @defer.inlineCallbacks
- def search_user_dir(self, user_id, search_term, limit):
+ async def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
Returns:
@@ -794,8 +822,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = yield self.db.execute(
- "search_user_dir", self.db.cursor_to_dict, sql, *args
+ results = await self.db_pool.execute(
+ "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
)
limited = len(results) > limit
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ec6b8a4ffd..2f7c95fc74 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,35 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
class UserErasureWorkerStore(SQLBaseStore):
@cached()
- def is_user_erased(self, user_id):
+ async def is_user_erased(self, user_id: str) -> bool:
"""
Check if the given user id has requested erasure
Args:
- user_id (str): full user id to check
+ user_id: full user id to check
Returns:
- Deferred[bool]: True if the user has requested erasure
+ True if the user has requested erasure
"""
- return self.db.simple_select_onecol(
+ result = await self.db_pool.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
desc="is_user_erased",
- ).addCallback(operator.truth)
+ )
+ return bool(result)
- @cachedList(
- cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
- )
- def are_users_erased(self, user_ids):
+ @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+ async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check
Returns:
- Deferred[dict[str, bool]]:
+ dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -65,16 +62,15 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
- res = {u: u in erased_users for u in user_ids}
- return res
+ return {u: u in erased_users for u in user_ids}
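Since are_users_erased is a @cachedList batch over the per-user is_user_erased cache, typical callers gather all the senders of interest and do a single lookup. A short, hypothetical usage sketch:

async def filter_out_erased_senders(store, events):
    erased = await store.are_users_erased({e.sender for e in events})
    # Keep only events whose sender has not requested erasure.
    return [e for e in events if not erased.get(e.sender, False)]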
class UserErasureStore(UserErasureWorkerStore):
- def mark_user_erased(self, user_id):
+ async def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
- user_id (str): full user_id to be erased
+ user_id: full user_id to be erased
"""
def f(txn):
@@ -88,4 +84,26 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db.runInteraction("mark_user_erased", f)
+ await self.db_pool.runInteraction("mark_user_erased", f)
+
+ async def mark_user_not_erased(self, user_id: str) -> None:
+ """Indicate that user_id is no longer erased.
+
+ Args:
+ user_id: full user_id to be un-erased
+ """
+
+ def f(txn):
+ # first check if they are already in the list
+ txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
+ if not txn.fetchone():
+ return
+
+ # They are there, delete them.
+ self.simple_delete_one_txn(
+ txn, "erased_users", keyvalues={"user_id": user_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
+
+ await self.db_pool.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/databases/state/__init__.py
index 86e09f6229..c90d022899 100644
--- a/synapse/storage/data_stores/state/__init__.py
+++ b/synapse/storage/databases/state/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401
+from synapse.storage.databases.state.store import StateGroupDataStore # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index ff000bc9ec..139085b672 100644
--- a/synapse/storage/data_stores/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -15,12 +15,8 @@
import logging
-from six import iteritems
-
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
@@ -64,7 +60,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
count = 0
while next_group:
- next_group = self.db.simple_select_one_onecol_txn(
+ next_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -167,7 +163,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
):
break
- next_group = self.db.simple_select_one_onecol_txn(
+ next_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -184,24 +180,23 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
index_name="state_groups_room_id_idx",
table="state_groups",
columns=["room_id"],
)
- @defer.inlineCallbacks
- def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -214,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
- rows = yield self.db.execute(
+ rows = await self.db_pool.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@@ -280,17 +275,17 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
delta_state = {
key: value
- for key, value in iteritems(curr_state)
+ for key, value in curr_state.items()
if prev_state.get(key, None) != value
}
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
values={
@@ -299,13 +294,13 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -316,7 +311,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(delta_state)
+ for key, state_id in delta_state.items()
],
)
@@ -326,25 +321,24 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
"max_group": max_group,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
)
return False, batch_size
- finished, result = yield self.db.runInteraction(
+ finished, result = await self.db_pool.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
if finished:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
return result * BATCH_SIZE_SCALE_FACTOR
- @defer.inlineCallbacks
- def _background_index_state(self, progress, batch_size):
+ async def _background_index_state(self, progress, batch_size):
def reindex_txn(conn):
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
@@ -367,8 +361,10 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- yield self.db.runWithConnection(reindex_txn)
+ await self.db_pool.runWithConnection(reindex_txn)
- yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+ await self.db_pool.updates._end_background_update(
+ self.STATE_GROUP_INDEX_UPDATE_NAME
+ )
return 1
diff --git a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql b/synapse/storage/databases/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
+++ b/synapse/storage/databases/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql
index 1450313bfa..1450313bfa 100644
--- a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
+++ b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql
index 33980d02f0..33980d02f0 100644
--- a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/35/state.sql b/synapse/storage/databases/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/data_stores/state/schema/delta/35/state.sql
+++ b/synapse/storage/databases/state/schema/delta/35/state.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py
index 9fd1ccf6f7..9fd1ccf6f7 100644
--- a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql
index 7916ef18b2..7916ef18b2 100644
--- a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
+++ b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/databases/state/schema/full_schemas/54/full.sql
index 35f97d6b3d..35f97d6b3d 100644
--- a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
+++ b/synapse/storage/databases/state/schema/full_schemas/54/full.sql
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres
index fcd926c9fb..fcd926c9fb 100644
--- a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
+++ b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/databases/state/store.py
index f3ad1e4369..e924f1ca3b 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -17,16 +17,13 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
-from six import iteritems
-from six.moves import range
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.state import StateFilter
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -54,7 +51,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateGroupDataStore, self).__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
@@ -95,8 +92,16 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"*stateGroupMembersCache*", 500000,
)
+ def get_max_state_group_txn(txn: Cursor):
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ return txn.fetchone()[0]
+
+ self._state_group_seq_gen = build_sequence_generator(
+ self.database_engine, get_max_state_group_txn, "state_group_id_seq"
+ )
+
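Replacing database_engine.get_next_state_group_id with a sequence generator centralises state-group ID allocation. The sketch below is an assumption about the two behaviours build_sequence_generator is expected to choose between (it is not the real synapse.storage.util.sequence module): on Postgres, IDs come from the named database sequence, while elsewhere a local counter is seeded once via the get_max callback defined above:

import threading

class LocalSequenceSketch:
    def __init__(self, get_first_callback):
        self._lock = threading.Lock()
        self._get_first = get_first_callback
        self._current = None

    def get_next_id_txn(self, txn) -> int:
        with self._lock:
            if self._current is None:
                # Seed from the current maximum, e.g. get_max_state_group_txn.
                self._current = self._get_first(txn)
            self._current += 1
            return self._current

class PostgresSequenceSketch:
    def __init__(self, sequence_name: str):
        self._name = sequence_name

    def get_next_id_txn(self, txn) -> int:
        # psycopg2 uses %s placeholders.
        txn.execute("SELECT nextval(%s)", (self._name,))
        return txn.fetchone()[0]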
@cached(max_entries=10000, iterable=True)
- def get_state_group_delta(self, state_group):
+ async def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
@@ -105,7 +110,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""
def _get_state_group_delta_txn(txn):
- prev_group = self.db.simple_select_one_onecol_txn(
+ prev_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
@@ -116,7 +121,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.db.simple_select_list_txn(
+ delta_ids = self.db_pool.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
@@ -128,14 +133,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)
- @defer.inlineCallbacks
- def _get_state_groups_from_groups(
+ async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
- ):
+ ) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@@ -144,13 +148,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ Dict of state group to state map.
"""
results = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
- res = yield self.db.runInteraction(
+ res = await self.db_pool.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
@@ -199,10 +203,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
- @defer.inlineCallbacks
- def _get_state_for_groups(
+ async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[int, StateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -212,7 +215,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ Dict of state group to state map.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@@ -221,14 +224,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
(
non_member_state,
incomplete_groups_nm,
- ) = yield self._get_state_for_groups_using_cache(
+ ) = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter
)
- (
- member_state,
- incomplete_groups_m,
- ) = yield self._get_state_for_groups_using_cache(
+ (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter
)
@@ -249,7 +249,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
- group_to_state_dict = yield self._get_state_groups_from_groups(
+ group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
)
@@ -263,7 +263,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
- for group, group_state_dict in iteritems(group_to_state_dict):
+ for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
@@ -341,11 +341,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
else:
non_member_types = non_member_filter.concrete_types()
- for group, group_state_dict in iteritems(group_to_state_dict):
+ for group, group_state_dict in group_to_state_dict.items():
state_dict_members = {}
state_dict_non_members = {}
- for k, v in iteritems(group_state_dict):
+ for k, v in group_state_dict.items():
if k[0] == EventTypes.Member:
state_dict_members[k] = v
else:
@@ -365,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
- def store_state_group(
+ async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
- ):
+ ) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
@@ -381,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to event_id.
Returns:
- Deferred[int]: The state group ID
+ The state group ID
"""
def _store_state_group_txn(txn):
@@ -389,9 +389,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
- state_group = self.database_engine.get_next_state_group_id(txn)
+ state_group = self._state_group_seq_gen.get_next_id_txn(txn)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="state_groups",
values={"id": state_group, "room_id": room_id, "event_id": event_id},
@@ -400,7 +400,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if prev_group:
- is_in_db = self.db.simple_select_one_onecol_txn(
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
@@ -415,13 +415,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -432,11 +432,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(delta_ids)
+ for key, state_id in delta_ids.items()
],
)
else:
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -447,7 +447,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(current_state_ids)
+ for key, state_id in current_state_ids.items()
],
)
@@ -458,7 +458,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_member_state_ids = {
s: ev
- for (s, ev) in iteritems(current_state_ids)
+ for (s, ev) in current_state_ids.items()
if s[0] == EventTypes.Member
}
txn.call_after(
@@ -470,7 +470,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_non_member_state_ids = {
s: ev
- for (s, ev) in iteritems(current_state_ids)
+ for (s, ev) in current_state_ids.items()
if s[0] != EventTypes.Member
}
txn.call_after(
@@ -482,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_group
- return self.db.runInteraction("store_state_group", _store_state_group_txn)
+ return await self.db_pool.runInteraction(
+ "store_state_group", _store_state_group_txn
+ )
- def purge_unreferenced_state_groups(
+ async def purge_unreferenced_state_groups(
self, room_id: str, state_groups_to_delete
- ) -> defer.Deferred:
+ ) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -497,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to delete.
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
@@ -509,7 +511,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
@@ -536,15 +538,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
curr_state = curr_state[sg]
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="state_group_edges", keyvalues={"state_group": sg}
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -555,7 +557,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in iteritems(curr_state)
+ for key, state_id in curr_state.items()
],
)
@@ -569,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
((sg,) for sg in state_groups_to_delete),
)
- @defer.inlineCallbacks
- def get_previous_state_groups(self, state_groups):
+ async def get_previous_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Dict[int, int]:
"""Fetch the previous groups of the given state groups.
Args:
- state_groups (Iterable[int])
+ state_groups
Returns:
- Deferred[dict[int, int]]: mapping from state group to previous
- state group.
+ A mapping from state group to previous state group.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
@@ -592,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
- def purge_room_state(self, room_id, state_groups_to_delete):
+ async def purge_room_state(self, room_id, state_groups_to_delete):
"""Deletes all record of a room from state tables
Args:
@@ -600,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete (list[int]): State groups to delete
"""
- return self.db.runInteraction(
+ await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
@@ -611,7 +613,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="state_groups_state",
column="state_group",
@@ -622,7 +624,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# ... and the state group edges
logger.info("[purge] removing %s from state_group_edges", room_id)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="state_group_edges",
column="state_group",
@@ -633,7 +635,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# ... and the state groups
logger.info("[purge] removing %s from state_groups", room_id)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="state_groups",
column="id",
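The state store now allocates state group IDs via a generic sequence generator instead of the per-engine `get_next_state_group_id` hooks removed in the following hunks. A minimal sketch, assuming the engine exposes something like a `supports_sequences` flag and that a SQLite-side `LocalSequenceGenerator` exists; apart from `build_sequence_generator`, `PostgresSequenceGenerator` and the seeding callback shown in this diff, the names below are hypothetical.

import threading
from typing import Callable, Optional


class PostgresSequenceGenerator:
    """Sketch: hand out IDs from a native Postgres sequence."""

    def __init__(self, sequence_name: str):
        self._sequence_name = sequence_name

    def get_next_id_txn(self, txn) -> int:
        # The sequence name comes from our own code, never from user input.
        txn.execute("SELECT nextval('%s')" % (self._sequence_name,))
        return txn.fetchone()[0]


class LocalSequenceGenerator:
    """Sketch: emulate a sequence on SQLite with an in-process counter."""

    def __init__(self, get_first_callback: Callable[..., int]):
        self._callback = get_first_callback
        self._current = None  # type: Optional[int]
        self._lock = threading.Lock()

    def get_next_id_txn(self, txn) -> int:
        with self._lock:
            if self._current is None:
                # Seed from the current maximum, e.g. get_max_state_group_txn above.
                self._current = self._callback(txn)
            self._current += 1
            return self._current


def build_sequence_generator(database_engine, get_first_callback, sequence_name):
    # Hypothetical dispatch; the real check for whether an engine supports
    # native sequences may look different.
    if getattr(database_engine, "supports_sequences", False):
        return PostgresSequenceGenerator(sequence_name)
    return LocalSequenceGenerator(get_first_callback)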
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ab0bbe4bd3..908cbc79e3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def lock_table(self, txn, table: str) -> None:
...
- @abc.abstractmethod
- def get_next_state_group_id(self, txn) -> int:
- """Returns an int that can be used as a new state_group ID
- """
- ...
-
@property
@abc.abstractmethod
def server_version(self) -> str:
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 6c7d08a6f2..ff39281f85 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -92,7 +92,7 @@ class PostgresEngine(BaseDatabaseEngine):
errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
if ctype != "C":
- errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+ errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
if errors:
raise IncorrectDatabaseSetup(
@@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- txn.execute("SELECT nextval('state_group_id_seq')")
- return txn.fetchone()[0]
-
@property
def server_version(self):
"""Returns a string giving the server version. For example: '8.1.5'
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 215a949442..8a0f8c89d1 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def lock_table(self, txn, table):
return
- def get_next_state_group_id(self, txn):
- """Returns an int that can be used as a new state_group ID
- """
- # We do application locking here since if we're using sqlite then
- # we are a single process synapse.
- with self._current_state_group_id_lock:
- if self._current_state_group_id is None:
- txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- self._current_state_group_id = txn.fetchone()[0]
-
- self._current_state_group_id += 1
- return self._current_state_group_id
-
@property
def server_version(self):
"""Gets a string giving the server version. For example: '3.22.0'
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 4769b21529..afd10f7bae 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -22,6 +22,6 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
-class FetchKeyResult(object):
+class FetchKeyResult:
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index f159400a87..dbaeef91dd 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -20,21 +20,17 @@ import logging
from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple
-from six import iteritems
-from six.moves import range
-
from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.events import FrozenEvent
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
-from synapse.storage.data_stores import DataStores
-from synapse.storage.data_stores.main.events import DeltaState
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main.events import DeltaState
from synapse.types import StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -73,7 +69,7 @@ stale_forward_extremities_counter = Histogram(
)
-class _EventPeristenceQueue(object):
+class _EventPeristenceQueue:
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
"""
@@ -176,14 +172,14 @@ class _EventPeristenceQueue(object):
pass
-class EventsPersistenceStorage(object):
+class EventsPersistenceStorage:
"""High level interface for handling persisting newly received events.
Takes care of batching up events by room, and calculating the necessary
current state and forward extremity changes.
"""
- def __init__(self, hs, stores: DataStores):
+ def __init__(self, hs, stores: Databases):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.
@@ -196,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
- @defer.inlineCallbacks
- def persist_events(
+ async def persist_events(
self,
- events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ):
+ ) -> int:
"""
Write events to the database
Args:
@@ -211,14 +206,14 @@ class EventsPersistenceStorage(object):
which might update the current state etc.
Returns:
- Deferred[int]: the stream ordering of the latest persisted event
+ the stream ordering of the latest persisted event
"""
partitioned = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
+ for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
)
@@ -227,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
- max_persisted_id = yield self.main_store.get_current_events_token()
+ return self.main_store.get_current_events_token()
- return max_persisted_id
-
- @defer.inlineCallbacks
- def persist_event(
- self, event: FrozenEvent, context: EventContext, backfilled: bool = False
- ):
+ async def persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ) -> Tuple[int, int]:
"""
Returns:
- Deferred[Tuple[int, int]]: the stream ordering of ``event``,
- and the stream ordering of the latest persisted event
+ The stream ordering of `event`, and the stream ordering of the
+ latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@@ -250,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id)
- yield make_deferred_yieldable(deferred)
+ await make_deferred_yieldable(deferred)
- max_persisted_id = yield self.main_store.get_current_events_token()
+ max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
@@ -266,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events(
self,
- events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
"""Calculates the change to current state and forward extremities, and
@@ -319,7 +311,7 @@ class EventsPersistenceStorage(object):
(event, context)
)
- for room_id, ev_ctx_rm in iteritems(events_by_room):
+ for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
room_id
)
@@ -443,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities(
self,
room_id: str,
- event_contexts: List[Tuple[FrozenEvent, EventContext]],
+ event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
):
"""Calculates the new forward extremities for a room given events to
@@ -501,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events(
self,
room_id: str,
- events_context: List[Tuple[FrozenEvent, EventContext]],
+ events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -651,6 +643,10 @@ class EventsPersistenceStorage(object):
room_version = await self.main_store.get_room_version_id(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
+
+ # Avoid a circular import.
+ from synapse.state import StateResolutionStore
+
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
@@ -674,7 +670,7 @@ class EventsPersistenceStorage(object):
to_insert = {
key: ev_id
- for key, ev_id in iteritems(current_state)
+ for key, ev_id in current_state.items()
if ev_id != existing_state.get(key)
}
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined(
self,
room_id: str,
- ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+ ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState,
current_state: Optional[StateMap[str]],
potentially_left_users: Set[str],
@@ -786,9 +782,3 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
- async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
- """Mark the invite has having been rejected even though we failed to
- create a leave event for it.
- """
- return await self.persist_events_store.locally_reject_invite(user_id, room_id)
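With the `@defer.inlineCallbacks` decorators gone, callers interact with the persistence storage as plain coroutines. A hedged usage sketch; the way the high-level `Storage` object exposes this class (here assumed to be `storage.persistence`) is not shown in this diff, and the persist callback site is a placeholder.

async def persist_and_get_positions(storage, event, context):
    # persist_event now returns a plain tuple instead of a Deferred.
    event_stream_ordering, max_stream_id = await storage.persistence.persist_event(
        event, context
    )
    return event_stream_ordering, max_stream_id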
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 9cc3b51fe6..ee60e2a718 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -47,8 +47,24 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
- """Prepares a database for usage. Will either create all necessary tables
+OUTDATED_SCHEMA_ON_WORKER_ERROR = (
+ "Expected database schema version %i but got %i: run the main synapse process to "
+ "upgrade the database schema before starting worker processes."
+)
+
+EMPTY_DATABASE_ON_WORKER_ERROR = (
+ "Uninitialised database: run the main synapse process to prepare the database "
+ "schema before starting worker processes."
+)
+
+UNAPPLIED_DELTA_ON_WORKER_ERROR = (
+ "Database schema delta %s has not been applied: run the main synapse process to "
+ "upgrade the database schema before starting worker processes."
+)
+
+
+def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+ """Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
If `config` is None then prepare_database will assert that no upgrade is
@@ -60,37 +76,57 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
- data_stores (list[str]): The name of the data stores that will be used
- with this database. Defaults to all data stores.
+ databases (list[str]): The name of the databases that will be used
+ with this physical database. Defaults to all databases.
"""
try:
cur = db_conn.cursor()
+
+ logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
user_version, delta_files, upgraded = version_info
+ logger.info(
+ "%r: Existing schema is %i (+%i deltas)",
+ databases,
+ user_version,
+ len(delta_files),
+ )
+ # config should only be None when we are preparing an in-memory SQLite db,
+ # which should be empty.
if config is None:
- if user_version != SCHEMA_VERSION:
- # If we don't pass in a config file then we are expecting to
- # have already upgraded the DB.
- raise UpgradeDatabaseException(
- "Expected database schema version %i but got %i"
- % (SCHEMA_VERSION, user_version)
- )
- else:
- _upgrade_existing_database(
- cur,
- user_version,
- delta_files,
- upgraded,
- database_engine,
- config,
- data_stores=data_stores,
+ raise ValueError(
+ "config==None in prepare_database, but databse is not empty"
+ )
+
+ # if it's a worker app, refuse to upgrade the database, to avoid multiple
+ # workers doing it at once.
+ if config.worker_app is not None and user_version != SCHEMA_VERSION:
+ raise UpgradeDatabaseException(
+ OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, user_version)
)
+
+ _upgrade_existing_database(
+ cur,
+ user_version,
+ delta_files,
+ upgraded,
+ database_engine,
+ config,
+ databases=databases,
+ )
else:
- _setup_new_database(cur, database_engine, data_stores=data_stores)
+ logger.info("%r: Initialising new database", databases)
+
+ # if it's a worker app, refuse to set up a new database, to avoid
+ # multiple workers doing it at once.
+ if config and config.worker_app is not None:
+ raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR)
+
+ _setup_new_database(cur, database_engine, databases=databases)
# check if any of our configured dynamic modules want a database
if config is not None:
@@ -103,9 +139,9 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta
raise
-def _setup_new_database(cur, database_engine, data_stores):
- """Sets up the database by finding a base set of "full schemas" and then
- applying any necessary deltas, including schemas from the given data
+def _setup_new_database(cur, database_engine, databases):
+ """Sets up the physical database by finding a base set of "full schemas" and
+ then applying any necessary deltas, including schemas from the given data
stores.
The "full_schemas" directory has subdirectories named after versions. This
@@ -138,8 +174,8 @@ def _setup_new_database(cur, database_engine, data_stores):
Args:
cur (Cursor): a database cursor
database_engine (DatabaseEngine)
- data_stores (list[str]): The names of the data stores to instantiate
- on the given database.
+ databases (list[str]): The names of the databases to instantiate
+ on the given physical database.
"""
# We're about to set up a brand new database so we check that its
@@ -176,13 +212,13 @@ def _setup_new_database(cur, database_engine, data_stores):
directories.extend(
os.path.join(
dir_path,
- "data_stores",
- data_store,
+ "databases",
+ database,
"schema",
"full_schemas",
str(max_current_ver),
)
- for data_store in data_stores
+ for database in databases
)
directory_entries = []
@@ -219,7 +255,7 @@ def _setup_new_database(cur, database_engine, data_stores):
upgraded=False,
database_engine=database_engine,
config=None,
- data_stores=data_stores,
+ databases=databases,
is_empty=True,
)
@@ -231,10 +267,10 @@ def _upgrade_existing_database(
upgraded,
database_engine,
config,
- data_stores,
+ databases,
is_empty=False,
):
- """Upgrades an existing database.
+ """Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules
in *.py.
@@ -285,8 +321,8 @@ def _upgrade_existing_database(
config (synapse.config.homeserver.HomeServerConfig|None):
None if we are initialising a blank database, otherwise the application
config
- data_stores (list[str]): The names of the data stores to instantiate
- on the given database.
+ databases (list[str]): The names of the databases to instantiate
+ on the given physical database.
is_empty (bool): Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
@@ -295,6 +331,8 @@ def _upgrade_existing_database(
else:
assert config
+ is_worker = config and config.worker_app is not None
+
if current_version > SCHEMA_VERSION:
raise ValueError(
"Cannot use this database as it is too "
@@ -303,8 +341,8 @@ def _upgrade_existing_database(
# some of the deltas assume that config.server_name is set correctly, so now
# is a good time to run the sanity check.
- if not is_empty and "main" in data_stores:
- from synapse.storage.data_stores.main import check_database_before_upgrade
+ if not is_empty and "main" in databases:
+ from synapse.storage.databases.main import check_database_before_upgrade
check_database_before_upgrade(cur, database_engine, config)
@@ -322,7 +360,7 @@ def _upgrade_existing_database(
specific_engine_extensions = (".sqlite", ".postgres")
for v in range(start_ver, SCHEMA_VERSION + 1):
- logger.info("Upgrading schema to v%d", v)
+ logger.info("Applying schema deltas for v%d", v)
# We need to search both the global and per data store schema
# directories for schema updates.
@@ -330,11 +368,9 @@ def _upgrade_existing_database(
# First we find the directories to search in
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
directories = [delta_dir]
- for data_store in data_stores:
+ for database in databases:
directories.append(
- os.path.join(
- dir_path, "data_stores", data_store, "schema", "delta", str(v)
- )
+ os.path.join(dir_path, "databases", database, "schema", "delta", str(v))
)
# Used to check if we have any duplicate file names
@@ -384,9 +420,15 @@ def _upgrade_existing_database(
continue
root_name, ext = os.path.splitext(file_name)
+
if ext == ".py":
# This is a python upgrade module. We need to import into some
# package and then execute its `run_upgrade` function.
+ if is_worker:
+ raise PrepareDatabaseException(
+ UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
+ )
+
module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
module = imp.load_source(module_name, absolute_path, python_file)
@@ -401,10 +443,18 @@ def _upgrade_existing_database(
continue
elif ext == ".sql":
# A plain old .sql file, just read and execute it
+ if is_worker:
+ raise PrepareDatabaseException(
+ UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
+ )
logger.info("Applying schema %s", relative_path)
executescript(cur, absolute_path)
elif ext == specific_engine_extension and root_name.endswith(".sql"):
# A .sql file specific to our engine; just read and execute it
+ if is_worker:
+ raise PrepareDatabaseException(
+ UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path
+ )
logger.info("Applying engine-specific schema %s", relative_path)
executescript(cur, absolute_path)
elif ext in specific_engine_extensions and root_name.endswith(".sql"):
@@ -434,6 +484,8 @@ def _upgrade_existing_database(
(v, True),
)
+ logger.info("Schema now up to date")
+
def _apply_module_schemas(txn, database_engine, config):
"""Apply the module schemas for the dynamic modules, if any
@@ -571,7 +623,7 @@ def _get_or_create_schema_state(txn, database_engine):
@attr.s()
-class _DirectoryListing(object):
+class _DirectoryListing:
"""Helper class to store schema file name and the
absolute path to it.
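The worker-related checks added above all funnel into the three new error constants. A condensed sketch of that branching, pulled out into a standalone helper purely for illustration; the real code inlines these checks at the call sites shown in the hunks above, and the exception classes and constants are the ones defined in this module.

def _refuse_schema_changes_on_worker(
    config, user_version=None, delta_path=None, schema_version=SCHEMA_VERSION
):
    """Raise if applying this change would mean a worker mutates the schema."""
    if config is None or config.worker_app is None:
        return  # main process (or in-memory test db): allowed to make changes
    if delta_path is not None:
        # A delta file would have to be applied by this worker.
        raise PrepareDatabaseException(UNAPPLIED_DELTA_ON_WORKER_ERROR % (delta_path,))
    if user_version is None:
        # The database has no schema at all yet.
        raise UpgradeDatabaseException(EMPTY_DATABASE_ON_WORKER_ERROR)
    if user_version != schema_version:
        raise UpgradeDatabaseException(
            OUTDATED_SCHEMA_ON_WORKER_ERROR % (schema_version, user_version)
        )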
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
deleted file mode 100644
index 18a462f0ee..0000000000
--- a/synapse/storage/presence.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import namedtuple
-
-from synapse.api.constants import PresenceState
-
-
-class UserPresenceState(
- namedtuple(
- "UserPresenceState",
- (
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
- ),
- )
-):
- """Represents the current presence state of the user.
-
- user_id (str)
- last_active (int): Time in msec that the user last interacted with server.
- last_federation_update (int): Time in msec since either a) we sent a presence
- update to other servers or b) we received a presence update, depending
- on if is a local user or not.
- last_user_sync (int): Time in msec that the user last *completed* a sync
- (or event stream).
- status_msg (str): User set status message.
- """
-
- def as_dict(self):
- return dict(self._asdict())
-
- @staticmethod
- def from_dict(d):
- return UserPresenceState(**d)
-
- def copy_and_replace(self, **kwargs):
- return self._replace(**kwargs)
-
- @classmethod
- def default(cls, user_id):
- """Returns a default presence state.
- """
- return cls(
- user_id=user_id,
- state=PresenceState.OFFLINE,
- last_active_ts=0,
- last_federation_update_ts=0,
- last_user_sync_ts=0,
- status_msg=None,
- currently_active=False,
- )
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index fdc0abf5cf..bfa0a9fd06 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,62 +15,60 @@
import itertools
import logging
-
-from twisted.internet import defer
+from typing import Set
logger = logging.getLogger(__name__)
-class PurgeEventsStorage(object):
+class PurgeEventsStorage:
"""High level interface for purging rooms and event history.
"""
def __init__(self, hs, stores):
self.stores = stores
- @defer.inlineCallbacks
- def purge_room(self, room_id: str):
+ async def purge_room(self, room_id: str):
"""Deletes all record of a room
"""
- state_groups_to_delete = yield self.stores.main.purge_room(room_id)
- yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+ state_groups_to_delete = await self.stores.main.purge_room(room_id)
+ await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
- @defer.inlineCallbacks
- def purge_history(self, room_id, token, delete_local_events):
+ async def purge_history(
+ self, room_id: str, token: str, delete_local_events: bool
+ ) -> None:
"""Deletes room history before a certain point
Args:
- room_id (str):
+ room_id: The room ID
- token (str): A topological token to delete events before
+ token: A topological token to delete events before
- delete_local_events (bool):
+ delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
"""
- state_groups = yield self.stores.main.purge_history(
+ state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] finding state groups that can be deleted")
- sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+ sg_to_delete = await self._find_unreferenced_groups(state_groups)
- yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+ await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
- @defer.inlineCallbacks
- def _find_unreferenced_groups(self, state_groups):
+ async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
"""Used when purging history to figure out which state groups can be
deleted.
Args:
- state_groups (set[int]): Set of state groups referenced by events
+ state_groups: Set of state groups referenced by events
that are going to be deleted.
Returns:
- Deferred[set[int]] The set of state groups that can be deleted.
+ The set of state groups that can be deleted.
"""
# Graph of state group -> previous group
graph = {}
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search
- referenced = yield self.stores.main.get_referenced_state_groups(
+ referenced = await self.stores.main.get_referenced_state_groups(
current_search
)
referenced_groups |= referenced
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced.
current_search -= referenced
- edges = yield self.stores.state.get_previous_state_groups(current_search)
+ edges = await self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values())
# We don't bother re-handling groups we've already seen
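For reference, a simplified sketch of the traversal `_find_unreferenced_groups` performs: work in batches of 100, drop any group still referenced by an event, and keep walking prev-group edges until nothing new turns up. The two callbacks stand in for `get_referenced_state_groups` and `get_previous_state_groups` used above; the real method tracks a little more state than this.

import itertools


async def find_unreferenced_groups(state_groups, get_referenced, get_previous):
    seen = set(state_groups)
    to_delete = set()
    pending = set(state_groups)
    while pending:
        batch = set(itertools.islice(pending, 100))
        pending -= batch
        # Anything still referenced by an event must be kept.
        referenced = await get_referenced(batch)
        batch -= referenced
        to_delete |= batch
        # Follow prev-group edges so their ancestors are considered too.
        prevs = set((await get_previous(batch)).values())
        new = prevs - seen
        seen |= new
        pending |= new
    return to_delete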
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index d471ec9860..d30e3f11e7 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
@attr.s
-class PaginationChunk(object):
+class PaginationChunk:
"""Returned by relation pagination APIs.
Attributes:
@@ -51,7 +51,7 @@ class PaginationChunk(object):
@attr.s(frozen=True, slots=True)
-class RelationPaginationToken(object):
+class RelationPaginationToken:
"""Pagination token for relation pagination API.
As the results are in topological order, we can use the
@@ -82,7 +82,7 @@ class RelationPaginationToken(object):
@attr.s(frozen=True, slots=True)
-class AggregationPaginationToken(object):
+class AggregationPaginationToken:
"""Pagination token for relation aggregation pagination API.
As the results are order by count and then MAX(stream_ordering) of the
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index c522c80922..8f68d968f0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,15 +14,12 @@
# limitations under the License.
import logging
-from typing import Iterable, List, TypeVar
-
-from six import iteritems, itervalues
+from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -32,58 +29,57 @@ T = TypeVar("T")
@attr.s(slots=True)
-class StateFilter(object):
+class StateFilter:
"""A filter used when querying for state.
Attributes:
- types (dict[str, set[str]|None]): Map from type to set of state keys (or
- None). This specifies which state_keys for the given type to fetch
- from the DB. If None then all events with that type are fetched. If
- the set is empty then no events with that type are fetched.
- include_others (bool): Whether to fetch events with types that do not
+ types: Map from type to set of state keys (or None). This specifies
+ which state_keys for the given type to fetch from the DB. If None
+ then all events with that type are fetched. If the set is empty
+ then no events with that type are fetched.
+ include_others: Whether to fetch events with types that do not
appear in `types`.
"""
- types = attr.ib()
- include_others = attr.ib(default=False)
+ types = attr.ib(type=Dict[str, Optional[Set[str]]])
+ include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {k: v for k, v in iteritems(self.types) if v is not None}
+ self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod
- def all():
+ def all() -> "StateFilter":
"""Creates a filter that fetches everything.
Returns:
- StateFilter
+ The new state filter.
"""
return StateFilter(types={}, include_others=True)
@staticmethod
- def none():
+ def none() -> "StateFilter":
"""Creates a filter that fetches nothing.
Returns:
- StateFilter
+ The new state filter.
"""
return StateFilter(types={}, include_others=False)
@staticmethod
- def from_types(types):
+ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types
Args:
- types (Iterable[tuple[str, str|None]]): A list of type and state
- keys to fetch. A state_key of None fetches everything for
- that type
+ types: A list of type and state keys to fetch. A state_key of None
+ fetches everything for that type
Returns:
- StateFilter
+ The new state filter.
"""
- type_dict = {}
+ type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types:
if typ in type_dict:
if type_dict[typ] is None:
@@ -93,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None
continue
- type_dict.setdefault(typ, set()).add(s)
+ type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict)
@staticmethod
- def from_lazy_load_member_list(members):
+ def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
events for the given users
Args:
- members (iterable[str]): Set of user IDs
+ members: Set of user IDs
Returns:
- StateFilter
+ The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
- def return_expanded(self):
+ def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass
@@ -132,7 +128,7 @@ class StateFilter(object):
return all non-member events
Returns:
- StateFilter
+ The new state filter.
"""
if self.is_full():
@@ -150,7 +146,7 @@ class StateFilter(object):
has_non_member_wildcard = self.include_others or any(
state_keys is None
- for t, state_keys in iteritems(self.types)
+ for t, state_keys in self.types.items()
if t != EventTypes.Member
)
@@ -169,7 +165,7 @@ class StateFilter(object):
include_others=True,
)
- def make_sql_filter_clause(self):
+ def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause.
For example:
@@ -181,13 +177,12 @@ class StateFilter(object):
Returns:
- tuple[str, list]: The SQL string (may be empty) and arguments. An
- empty SQL string is returned when the filter matches everything
- (i.e. is "full").
+ The SQL string (may be empty) and arguments. An empty SQL string is
+ returned when the filter matches everything (i.e. is "full").
"""
where_clause = ""
- where_args = []
+ where_args = [] # type: List[str]
if self.is_full():
return where_clause, where_args
@@ -199,7 +194,7 @@ class StateFilter(object):
# First we build up a list of clauses for each type/state_key combo
clauses = []
- for etype, state_keys in iteritems(self.types):
+ for etype, state_keys in self.types.items():
if state_keys is None:
clauses.append("(type = ?)")
where_args.append(etype)
@@ -223,7 +218,7 @@ class StateFilter(object):
return where_clause, where_args
- def max_entries_returned(self):
+ def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if
known, otherwise returns None.
@@ -251,7 +246,7 @@ class StateFilter(object):
return dict(state_dict)
filtered_state = {}
- for k, v in iteritems(state_dict):
+ for k, v in state_dict.items():
typ, state_key = k
if typ in self.types:
state_keys = self.types[typ]
@@ -262,42 +257,42 @@ class StateFilter(object):
return filtered_state
- def is_full(self):
+ def is_full(self) -> bool:
"""Whether this filter fetches everything or not
Returns:
- bool
+ True if the filter fetches everything.
"""
return self.include_others and not self.types
- def has_wildcards(self):
+ def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch
specific state.
Returns:
- bool
+ True if the filter includes wildcards.
"""
return self.include_others or any(
- state_keys is None for state_keys in itervalues(self.types)
+ state_keys is None for state_keys in self.types.values()
)
- def concrete_types(self):
+ def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty).
Returns:
- list[tuple[str,str]]
+ A list of type/state_keys tuples.
"""
return [
(t, s)
- for t, state_keys in iteritems(self.types)
+ for t, state_keys in self.types.items()
if state_keys is not None
for s in state_keys
]
- def get_member_split(self):
+ def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching
against non member state.
@@ -309,7 +304,7 @@ class StateFilter(object):
state caches).
Returns:
- tuple[StateFilter, StateFilter]: The member and non member filters
+ The member and non member filters
"""
if EventTypes.Member in self.types:
@@ -324,84 +319,91 @@ class StateFilter(object):
member_filter = StateFilter.none()
non_member_filter = StateFilter(
- types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
+ types={k: v for k, v in self.types.items() if k != EventTypes.Member},
include_others=self.include_others,
)
return member_filter, non_member_filter
-class StateGroupStorage(object):
+class StateGroupStorage:
"""High level interface to fetching state for event.
"""
def __init__(self, hs, stores):
self.stores = stores
- def get_state_group_delta(self, state_group: int):
+ async def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
+ Args:
+ state_group: The state group used to retrieve state deltas.
+
Returns:
- Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+ Tuple[Optional[int], Optional[StateMap[str]]]:
(prev_group, delta_ids)
"""
- return self.stores.state.get_state_group_delta(state_group)
+ return await self.stores.state.get_state_group_delta(state_group)
- @defer.inlineCallbacks
- def get_state_groups_ids(self, _room_id, event_ids):
+ async def get_state_groups_ids(
+ self, _room_id: str, event_ids: Iterable[str]
+ ) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
- _room_id (str): id of the room for these events
- event_ids (iterable[str]): ids of the events
+ _room_id: id of the room for these events
+ event_ids: ids of the events
Returns:
- Deferred[dict[int, StateMap[str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}
- event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.state._get_state_for_groups(groups)
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
- @defer.inlineCallbacks
- def get_state_ids_for_group(self, state_group):
+ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
- state_group (int)
+ state_group: A state group for which we want to get the state IDs.
Returns:
- Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+ Resolves to a map of (type, state_key) -> event_id
"""
- group_to_state = yield self._get_state_for_groups((state_group,))
+ group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group]
- @defer.inlineCallbacks
- def get_state_groups(self, room_id, event_ids):
+ async def get_state_groups(
+ self, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids
+
+ Args:
+ room_id: ID of the room for these events.
+ event_ids: The event IDs to retrieve state for.
+
Returns:
- Deferred[dict[int, list[EventBase]]]:
- dict of state_group_id -> list of state events.
+ dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
- group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+ group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
- state_event_map = yield self.stores.main.get_events(
+ state_event_map = await self.stores.main.get_events(
[
ev_id
- for group_ids in itervalues(group_to_ids)
- for ev_id in itervalues(group_ids)
+ for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
],
get_prev_content=False,
)
@@ -409,15 +411,15 @@ class StateGroupStorage(object):
return {
group: [
state_event_map[v]
- for v in itervalues(event_id_map)
+ for v in event_id_map.values()
if v in state_event_map
]
- for group, event_id_map in iteritems(group_to_ids)
+ for group, event_id_map in group_to_ids.items()
}
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
- ):
+ ) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on
types of state events.
@@ -425,140 +427,148 @@ class StateGroupStorage(object):
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
+
Returns:
- Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
- @defer.inlineCallbacks
- def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+ async def get_state_for_events(
+ self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+ ):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
+
Args:
- event_ids (list[string])
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_ids: The events to fetch the state of.
+ state_filter: The state filter used to fetch state.
+
Returns:
- deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+ A dict of (event_id) -> (type, state_key) -> [state_events]
"""
- event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.state._get_state_for_groups(
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
- state_event_map = yield self.stores.main.get_events(
- [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+ state_event_map = await self.stores.main.get_events(
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in iteritems(group_to_state[group])
+ for k, v in group_to_state[group].items()
if v in state_event_map
}
- for event_id, group in iteritems(event_to_groups)
+ for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
- @defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+ async def get_state_ids_for_events(
+ self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+ ):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
- event_ids(list(str)): events whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_ids: events whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
Returns:
- A deferred dict from event_id -> (type, state_key) -> event_id
+ A dict from event_id -> (type, state_key) -> event_id
"""
- event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.state._get_state_for_groups(
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in iteritems(event_to_groups)
+ for event_id, group in event_to_groups.items()
}
return {event: event_to_state[event] for event in event_ids}
- @defer.inlineCallbacks
- def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+ async def get_state_for_event(
+ self, event_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
"""
Get the state dict corresponding to a particular event
Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
Returns:
- A deferred dict from (type, state_key) -> state_event
+ A dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_for_events([event_id], state_filter)
+ state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
- @defer.inlineCallbacks
- def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+ async def get_state_ids_for_event(
+ self, event_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
"""
Get the state dict corresponding to a particular event
Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
Returns:
- A deferred dict from (type, state_key) -> state_event
+ A dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+ state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Awaitable[Dict[int, StateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state groups for which we want to get the state.
+            state_filter: The state filter used to fetch state
from the database.
+
Returns:
- Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
+ Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
- def store_state_group(
- self, event_id, room_id, prev_group, delta_ids, current_state_ids
- ):
+ async def store_state_group(
+ self,
+ event_id: str,
+ room_id: str,
+ prev_group: Optional[int],
+ delta_ids: Optional[dict],
+ current_state_ids: dict,
+ ) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
- event_id (str): The event ID for which the state was calculated
- room_id (str)
- prev_group (int|None): A previous state group for the room, optional.
- delta_ids (dict|None): The delta between state at `prev_group` and
+ event_id: The event ID for which the state was calculated.
+ room_id: ID of the room for which the state was calculated.
+ prev_group: A previous state group for the room, optional.
+ delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
- current_state_ids (dict): The state to store. Map of (type, state_key)
+ current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
- Deferred[int]: The state group ID
+ The state group ID
"""
- return self.stores.state.store_state_group(
+ return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
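A hedged usage sketch of the now-typed `StateFilter` API, building the lazy-loading filter described above and splitting it into member and non-member parts; the user IDs and the `storage.state` attribute in the trailing comment are placeholders.

sf = StateFilter.from_lazy_load_member_list(
    ["@alice:example.com", "@bob:example.com"]
)
assert not sf.is_full() and sf.has_wildcards()

member_filter, non_member_filter = sf.get_member_split()

# Later, inside a coroutine with access to the high-level storage object:
#     state_ids = await storage.state.get_state_ids_for_event(
#         event_id, state_filter=sf
#     )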
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index daff81c5ee..2d2b560e74 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,12 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
-
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0bed2..b7eb4f8ac9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,16 +14,21 @@
# limitations under the License.
import contextlib
+import heapq
+import logging
import threading
from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, List, Set
from typing_extensions import Deque
-from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
+logger = logging.getLogger(__name__)
-class IdGenerator(object):
+
+class IdGenerator:
def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
self._next_id = _load_current_id(db_conn, table, column)
@@ -46,6 +51,8 @@ def _load_current_id(db_conn, table, column, step=1):
Returns:
int
"""
+ # debug logging for https://github.com/matrix-org/synapse/issues/7968
+ logger.info("initialising stream generator for %s(%s)", table, column)
cur = db_conn.cursor()
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
@@ -57,7 +64,7 @@ def _load_current_id(db_conn, table, column, step=1):
return (max if step > 0 else min)(current_id, step)
-class StreamIdGenerator(object):
+class StreamIdGenerator:
"""Used to generate new stream ids when persisting events while keeping
track of which transactions have been completed.
@@ -79,7 +86,7 @@ class StreamIdGenerator(object):
upwards, -1 to grow downwards.
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -94,10 +101,10 @@ class StreamIdGenerator(object):
)
self._unfinished_ids = deque() # type: Deque[int]
- def get_next(self):
+ async def get_next(self):
"""
Usage:
- with stream_id_gen.get_next() as stream_id:
+ with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -116,10 +123,10 @@ class StreamIdGenerator(object):
return manager()
- def get_next_mult(self, n):
+ async def get_next_mult(self, n):
"""
Usage:
- with stream_id_gen.get_next(n) as stream_ids:
+ with await stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -157,63 +164,13 @@ class StreamIdGenerator(object):
return self._current
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
-class ChainedIdGenerator(object):
- """Used to generate new stream ids where the stream must be kept in sync
- with another stream. It generates pairs of IDs, the first element is an
- integer ID for this stream, the second element is the ID for the stream
- that this stream needs to be kept in sync with."""
-
- def __init__(self, chained_generator, db_conn, table, column):
- self.chained_generator = chained_generator
- self._table = table
- self._lock = threading.Lock()
- self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
-
- def get_next(self):
- """
- Usage:
- with stream_id_gen.get_next() as (stream_id, chained_id):
- # ... persist event ...
- """
- with self._lock:
- self._current_max += 1
- next_id = self._current_max
- chained_id = self.chained_generator.get_current_token()
-
- self._unfinished_ids.append((next_id, chained_id))
-
- @contextlib.contextmanager
- def manager():
- try:
- yield (next_id, chained_id)
- finally:
- with self._lock:
- self._unfinished_ids.remove((next_id, chained_id))
-
- return manager()
-
- def get_current_token(self):
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
- with self._lock:
- if self._unfinished_ids:
- stream_id, chained_id = self._unfinished_ids[0]
- return stream_id - 1, chained_id
-
- return self._current_max, self.chained_generator.get_current_token()
-
- def advance(self, token: int):
- """Stub implementation for advancing the token when receiving updates
- over replication; raises an exception as this instance should be the
- only source of updates.
+        For streams with a single writer this is equivalent to
+        `get_current_token`.
"""
-
- raise Exception(
- "Attempted to advance token on source for table %r", self._table
- )
+ return self.get_current_token()
class MultiWriterIdGenerator:
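# Illustrative sketch, not part of the patch: how calling code adopts the new
# coroutine-returning StreamIdGenerator.get_next(). The store object, its
# `stream_id_gen` attribute and `_persist_txn` helper are hypothetical; only
# the `with await ... as stream_id` pattern mirrors the change above.
async def persist_thing(store, thing) -> int:
    # get_next() must now be awaited first; the awaited result is the context
    # manager that allocates the stream ID and releases it on exit.
    with await store.stream_id_gen.get_next() as stream_id:
        await store.db_pool.runInteraction(
            "persist_thing", store._persist_txn, thing, stream_id
        )
    # Once the block exits, the ID is marked as finished and
    # get_current_token() is free to advance past it.
    return stream_id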
@@ -233,25 +190,32 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ positive: Whether the IDs are positive (true) or negative (false).
+ When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
def __init__(
self,
db_conn,
- db: Database,
+ db: DatabasePool,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
sequence_name: str,
+ positive: bool = True,
):
self._db = db
self._instance_name = instance_name
- self._sequence_name = sequence_name
+ self._positive = positive
+ self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
+ # Note: If we are a negative stream then we still store all the IDs as
+ # positive to make life easier for us, and simply negate the IDs when we
+ # return them.
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
@@ -260,16 +224,38 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # We track the max position where we know everything before has been
+ # persisted. This is done by a) looking at the min across all instances
+ # and b) noting that if we have seen a run of persisted positions
+ # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+ #
+    # Note: There is no guarantee that the IDs generated by the sequence
+ # will be gapless; gaps can form when e.g. a transaction was rolled
+ # back. This means that sometimes we won't be able to skip forward the
+ # position even though everything has been persisted. However, since
+    # gaps should be relatively rare, it's still worth doing the bookkeeping
+ # that allows us to skip forwards when there are gapless runs of
+ # positions.
+ self._persisted_upto_position = (
+ min(self._current_positions.values()) if self._current_positions else 0
+ )
+ self._known_persisted_positions = [] # type: List[int]
+
+ self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
+        # For a positive stream, aggregate via MAX. For a negative stream, use
+        # MIN *and* negate the result to get a positive number.
sql = """
- SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
+ "agg": "MAX" if self._positive else "-MIN",
}
cur = db_conn.cursor()
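# Illustrative sketch, not part of the patch: the query the interpolation above
# produces for a hypothetical negative stream over a `backfill_positions`
# table. With positive=False the aggregate becomes "-MIN", so a stored row
# (instance, -7) loads as the positive position 7; it is only negated again
# via _return_factor when handed back to callers.
sql = """
    SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
    GROUP BY %(instance)s
""" % {
    "instance": "instance_name",
    "id": "stream_id",
    "table": "backfill_positions",
    "agg": "-MIN",  # would be "MAX" for a positive stream
}
# -> SELECT instance_name, -MIN(stream_id) FROM backfill_positions
#    GROUP BY instance_name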
@@ -282,10 +268,11 @@ class MultiWriterIdGenerator:
return current_positions
- def _load_next_id_txn(self, txn):
- txn.execute("SELECT nextval(?)", (self._sequence_name,))
- (next_id,) = txn.fetchone()
- return next_id
+ def _load_next_id_txn(self, txn) -> int:
+ return self._sequence_gen.get_next_id_txn(txn)
+
+ def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+ return self._sequence_gen.get_next_mult_txn(txn, n)
async def get_next(self):
"""
@@ -298,20 +285,49 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
- assert self.get_current_token() < next_id
-
with self._lock:
+ assert self._current_positions.get(self._instance_name, 0) < next_id
+
self._unfinished_ids.add(next_id)
@contextlib.contextmanager
def manager():
try:
- yield next_id
+                # Multiply by the return factor so that the ID has the correct sign.
+ yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)
return manager()
+ async def get_next_mult(self, n: int):
+ """
+ Usage:
+ with await stream_id_gen.get_next_mult(5) as stream_ids:
+ # ... persist events ...
+ """
+ next_ids = await self._db.runInteraction(
+ "_load_next_mult_id", self._load_next_mult_id_txn, n
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self._lock:
+ assert max(self._current_positions.values(), default=0) < min(next_ids)
+
+ self._unfinished_ids.update(next_ids)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield [self._return_factor * i for i in next_ids]
+ finally:
+ for i in next_ids:
+ self._mark_id_as_finished(i)
+
+ return manager()
+
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
@@ -328,7 +344,7 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
- return next_id
+ return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
@@ -344,33 +360,92 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)
- def get_current_token(self, instance_name: str = None) -> int:
- """Gets the current position of a named writer (defaults to current
- instance).
+ self._add_persisted_position(next_id)
- Returns 0 if we don't have a position for the named writer (likely due
- to it being a new writer).
+ def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
"""
- if instance_name is None:
- instance_name = self._instance_name
+ # Currently we don't support this operation, as it's not obvious how to
+ # condense the stream positions of multiple writers into a single int.
+ raise NotImplementedError()
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+ """
with self._lock:
- return self._current_positions.get(instance_name, 0)
+ return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
"""
with self._lock:
- return dict(self._current_positions)
+ return {
+ name: self._return_factor * i
+ for name, i in self._current_positions.items()
+ }
def advance(self, instance_name: str, new_id: int):
"""Advance the postion of the named writer to the given ID, if greater
than existing entry.
"""
+ new_id *= self._return_factor
+
with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
+
+ self._add_persisted_position(new_id)
+
+ def get_persisted_upto_position(self) -> int:
+ """Get the max position where all previous positions have been
+ persisted.
+
+ Note: In the worst case scenario this will be equal to the minimum
+ position across writers. This means that the returned position here can
+ lag if one writer doesn't write very often.
+ """
+
+ with self._lock:
+ return self._return_factor * self._persisted_upto_position
+
+ def _add_persisted_position(self, new_id: int):
+ """Record that we have persisted a position.
+
+        This is used to keep `_persisted_upto_position` up to date.
+ """
+
+        # We require that the caller holds the lock.
+ assert self._lock.locked()
+
+ heapq.heappush(self._known_persisted_positions, new_id)
+
+ # We move the current min position up if the minimum current positions
+ # of all instances is higher (since by definition all positions less
+        # than that have been persisted).
+ min_curr = min(self._current_positions.values())
+ self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+ # We now iterate through the seen positions, discarding those that are
+        # less than the current min position, and incrementing the min position
+        # if it's exactly one greater.
+ #
+ # This is also where we discard items from `_known_persisted_positions`
+ # (to ensure the list doesn't infinitely grow).
+ while self._known_persisted_positions:
+ if self._known_persisted_positions[0] <= self._persisted_upto_position:
+ heapq.heappop(self._known_persisted_positions)
+ elif (
+ self._known_persisted_positions[0] == self._persisted_upto_position + 1
+ ):
+ heapq.heappop(self._known_persisted_positions)
+ self._persisted_upto_position += 1
+ else:
+ # There was a gap in seen positions, so there is nothing more to
+ # do.
+ break
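# Illustrative sketch, not part of the patch: the heap bookkeeping above in
# miniature. Positions 5, 7 and 6 are "persisted" out of order; the upto
# position only jumps to 7 once the run 5..7 becomes gapless.
import heapq
from typing import List

persisted_upto = 4
known = []  # type: List[int]  # min-heap of positions seen so far

for new_id in (5, 7, 6):
    heapq.heappush(known, new_id)
    while known:
        if known[0] <= persisted_upto:
            heapq.heappop(known)  # already covered
        elif known[0] == persisted_upto + 1:
            heapq.heappop(known)  # extends the gapless run
            persisted_upto += 1
        else:
            break  # gap: 7 arrived before 6
    print(new_id, persisted_upto)  # prints 5 5, then 7 5, then 6 7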
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
new file mode 100644
index 0000000000..ffc1894748
--- /dev/null
+++ b/synapse/storage/util/sequence.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import threading
+from typing import Callable, List, Optional
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
+
+
+class SequenceGenerator(metaclass=abc.ABCMeta):
+ """A class which generates a unique sequence of integers"""
+
+ @abc.abstractmethod
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ """Gets the next ID in the sequence"""
+ ...
+
+
+class PostgresSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses a postgres sequence"""
+
+ def __init__(self, sequence_name: str):
+ self._sequence_name = sequence_name
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ return txn.fetchone()[0]
+
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ txn.execute(
+ "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+ )
+ return [i for (i,) in txn]
+
+
+GetFirstCallbackType = Callable[[Cursor], int]
+
+
+class LocalSequenceGenerator(SequenceGenerator):
+ """An implementation of SequenceGenerator which uses local locking
+
+ This only works reliably if there are no other worker processes generating IDs at
+ the same time.
+ """
+
+ def __init__(self, get_first_callback: GetFirstCallbackType):
+ """
+ Args:
+ get_first_callback: a callback which is called on the first call to
+            get_next_id_txn; should return the current maximum id
+ """
+        # The callback; this is cleared after it is called, so that it can be GCed.
+ self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
+
+ # The current max value, or None if we haven't looked in the DB yet.
+ self._current_max_id = None # type: Optional[int]
+ self._lock = threading.Lock()
+
+ def get_next_id_txn(self, txn: Cursor) -> int:
+ # We do application locking here since if we're using sqlite then
+ # we are a single process synapse.
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ self._current_max_id += 1
+ return self._current_max_id
+
+
+def build_sequence_generator(
+ database_engine: BaseDatabaseEngine,
+ get_first_callback: GetFirstCallbackType,
+ sequence_name: str,
+) -> SequenceGenerator:
+ """Get the best impl of SequenceGenerator available
+
+ This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
+ sqlite.
+
+ Args:
+ database_engine: the database engine we are connected to
+        get_first_callback: a callback which returns the current maximum ID;
+            only used to seed the generator when we're on sqlite.
+ sequence_name: the name of a postgres sequence to use.
+ """
+ if isinstance(database_engine, PostgresEngine):
+ return PostgresSequenceGenerator(sequence_name)
+ else:
+ return LocalSequenceGenerator(get_first_callback)
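# Illustrative sketch, not part of the patch: wiring build_sequence_generator
# into a store. The table, column and sequence names are hypothetical; the
# get_first_callback only runs on sqlite, where there is no real sequence to
# draw from.
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.util.sequence import SequenceGenerator, build_sequence_generator


def make_cache_stream_sequence(database_engine: BaseDatabaseEngine) -> SequenceGenerator:
    def get_max_cache_stream_id(txn) -> int:
        # Seed the locally-locked generator from the existing data (sqlite only).
        txn.execute("SELECT COALESCE(MAX(stream_id), 0) FROM cache_stream")
        return txn.fetchone()[0]

    return build_sequence_generator(
        database_engine,
        get_max_cache_stream_id,
        "cache_stream_seq",  # name of the postgres sequence (hypothetical)
    )

# Usage inside a database transaction:
#     seq = make_cache_stream_sequence(database_engine)
#     next_id = seq.get_next_id_txn(txn)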