diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self,
stream_name: str,
instance_name: str,
- token: StreamToken,
+ token: int,
rows: Iterable[Any],
) -> None:
pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b9a8ca997e..d64910aded 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+ TYPE_CHECKING,
+ AsyncContextManager,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Optional,
+)
+
+import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
from . import engines
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
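+# Callback signatures for module-registered background update controllers.
+# `on_update(update_name, database_name, oneshot)` must return an async
+# context manager which, on entry, returns the target duration in
+# milliseconds for the next batch; the batch size callbacks both take
+# `(update_name, database_name)` and return the initial/minimum batch size.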
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+ """A handler for a given background update.
+
+ Attributes:
+ callback: The function to call to make progress on the background
+ update.
+        oneshot: Whether the update is likely to happen all in one go, ignoring
+ the supplied target duration, e.g. index creation. This is used by
+ the update controller to help correctly schedule the update.
+ """
+
+ callback: Callable[[JsonDict, int], Awaitable[int]]
+ oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+ BACKGROUND_UPDATE_INTERVAL_MS = 1000
+ BACKGROUND_UPDATE_DURATION_MS = 100
+
+ def __init__(self, sleep: bool, clock: Clock):
+ self._sleep = sleep
+ self._clock = clock
+
+ async def __aenter__(self) -> int:
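+        # Default behaviour when no module controller is registered: optionally
+        # sleep to rate-limit updates, then return the default target duration.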
+ if self._sleep:
+ await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+ return self.BACKGROUND_UPDATE_DURATION_MS
+
+ async def __aexit__(self, *exc) -> None:
+ pass
+
+
class BackgroundUpdatePerformance:
"""Tracks the how long a background update is taking to update its items"""
@@ -82,22 +131,24 @@ class BackgroundUpdater:
process and autotuning the batch size.
"""
- MINIMUM_BACKGROUND_BATCH_SIZE = 100
+ MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100
- BACKGROUND_UPDATE_INTERVAL_MS = 1000
- BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
+ self._database_name = database.name()
+
# if a background update is currently running, its name.
self._current_background_update: Optional[str] = None
+ self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
+ self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
+ self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
+
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
- self._background_update_handlers: Dict[
- str, Callable[[JsonDict, int], Awaitable[int]]
- ] = {}
+ self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
self._all_done = False
# Whether we're currently running updates
@@ -107,6 +158,83 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API.
self.enabled = True
+ def register_update_controller_callbacks(
+ self,
+ on_update: ON_UPDATE_CALLBACK,
+ default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
+        min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
+ ) -> None:
+ """Register callbacks from a module for each hook."""
+ if self._on_update_callback is not None:
+ logger.warning(
+ "More than one module tried to register callbacks for controlling"
+ " background updates. Only the callbacks registered by the first module"
+ " (in order of appearance in Synapse's configuration file) that tried to"
+ " do so will be called."
+ )
+
+ return
+
+ self._on_update_callback = on_update
+
+ if default_batch_size is not None:
+ self._default_batch_size_callback = default_batch_size
+
+ if min_batch_size is not None:
+ self._min_batch_size_callback = min_batch_size
+
+ def _get_context_manager_for_update(
+ self,
+ sleep: bool,
+ update_name: str,
+ database_name: str,
+ oneshot: bool,
+ ) -> AsyncContextManager[int]:
+ """Get a context manager to run a background update with.
+
+        If a module has registered an `on_update` callback, use the context manager
+ it returns.
+
+        Otherwise, returns a default context manager, which optionally sleeps
+        before returning a fixed target duration.
+
+ Args:
+ sleep: Whether we can sleep between updates.
+ update_name: The name of the update.
+ database_name: The name of the database the update is being run on.
+ oneshot: Whether the update will complete all in one go, e.g. index creation.
+ In such cases the returned target duration is ignored.
+
+ Returns:
+            An async context manager which, on entry, returns the target duration
+            in milliseconds that the background update should run for.
+
+ Note: this is a *target*, and an iteration may take substantially longer or
+ shorter.
+ """
+ if self._on_update_callback is not None:
+ return self._on_update_callback(update_name, database_name, oneshot)
+
+ return _BackgroundUpdateContextManager(sleep, self._clock)
+
+ async def _default_batch_size(self, update_name: str, database_name: str) -> int:
+ """The batch size to use for the first iteration of a new background
+ update.
+ """
+ if self._default_batch_size_callback is not None:
+ return await self._default_batch_size_callback(update_name, database_name)
+
+ return self.DEFAULT_BACKGROUND_BATCH_SIZE
+
+ async def _min_batch_size(self, update_name: str, database_name: str) -> int:
+ """A lower bound on the batch size of a new background update.
+
+ Used to ensure that progress is always made. Must be greater than 0.
+ """
+ if self._min_batch_size_callback is not None:
+ return await self._min_batch_size_callback(update_name, database_name)
+
+ return self.MINIMUM_BACKGROUND_BATCH_SIZE
+
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any."""
@@ -122,6 +250,8 @@ class BackgroundUpdater:
def start_doing_background_updates(self) -> None:
if self.enabled:
+ # if we start a new background update, not all updates are done.
+ self._all_done = False
run_as_background_process("background_updates", self.run_background_updates)
async def run_background_updates(self, sleep: bool = True) -> None:
@@ -133,13 +263,8 @@ class BackgroundUpdater:
try:
logger.info("Starting background schema updates")
while self.enabled:
- if sleep:
- await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
-
try:
- result = await self.do_next_background_update(
- self.BACKGROUND_UPDATE_DURATION_MS
- )
+ result = await self.do_next_background_update(sleep)
except Exception:
logger.exception("Error doing update")
else:
@@ -201,13 +326,15 @@ class BackgroundUpdater:
return not update_exists
- async def do_next_background_update(self, desired_duration_ms: float) -> bool:
+ async def do_next_background_update(self, sleep: bool = True) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
Args:
- desired_duration_ms: How long we want to spend updating.
+ sleep: Whether to limit how quickly we run background updates or
+ not.
+
Returns:
True if we have finished running all the background updates, otherwise False
"""
@@ -250,7 +377,19 @@ class BackgroundUpdater:
self._current_background_update = upd["update_name"]
- await self._do_background_update(desired_duration_ms)
+ # We have a background update to run, otherwise we would have returned
+ # early.
+ assert self._current_background_update is not None
+ update_info = self._background_update_handlers[self._current_background_update]
+
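+        # The context manager is responsible for any rate limiting (the default
+        # implementation sleeps between batches) and tells us how long this
+        # batch should aim to run for.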
+ async with self._get_context_manager_for_update(
+ sleep=sleep,
+ update_name=self._current_background_update,
+ database_name=self._database_name,
+ oneshot=update_info.oneshot,
+ ) as desired_duration_ms:
+ await self._do_background_update(desired_duration_ms)
+
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
@@ -258,7 +397,7 @@ class BackgroundUpdater:
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
- update_handler = self._background_update_handlers[update_name]
+ update_handler = self._background_update_handlers[update_name].callback
performance = self._background_update_performance.get(update_name)
@@ -271,9 +410,14 @@ class BackgroundUpdater:
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
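+            # e.g. if the last batch processed 100 items in 50ms (2 items/ms),
+            # a 100ms target duration gives a batch size of 200.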
# Clamp the batch size so that we always make progress
- batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
+ batch_size = max(
+ batch_size,
+ await self._min_batch_size(update_name, self._database_name),
+ )
else:
- batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
+ batch_size = await self._default_batch_size(
+ update_name, self._database_name
+ )
progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates",
@@ -292,6 +436,8 @@ class BackgroundUpdater:
duration_ms = time_stop - time_start
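+        # Update the performance stats first, so that the rates logged below
+        # include this batch.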
+ performance.update(items_updated, duration_ms)
+
logger.info(
"Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@@ -304,8 +450,6 @@ class BackgroundUpdater:
batch_size,
)
- performance.update(items_updated, duration_ms)
-
return len(self._background_update_performance)
def register_background_update_handler(
@@ -329,7 +473,9 @@ class BackgroundUpdater:
update_name: The name of the update that this code handles.
update_handler: The function that does the update.
"""
- self._background_update_handlers[update_name] = update_handler
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ update_handler
+ )
def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
@@ -451,7 +597,9 @@ class BackgroundUpdater:
await self._end_background_update(update_name)
return 1
- self.register_background_update_handler(update_name, updater)
+ self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+ updater, oneshot=True
+ )
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d4cab69ebf..0693d39006 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
R = TypeVar("R")
@@ -235,7 +235,7 @@ class LoggingTransaction:
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks
- def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
+ def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -247,7 +247,7 @@ class LoggingTransaction:
self.after_callbacks.append((callback, args, kwargs))
def call_on_exception(
- self, callback: Callable[..., None], *args: Any, **kwargs: Any
+ self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 7c0f953365..ab8766c75b 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -599,6 +599,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
+ REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -614,14 +615,18 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
- self.db_pool.updates.register_background_update_handler(
- self.REMOVE_DELETED_DEVICES,
- self._remove_deleted_devices_from_device_inbox,
+        # Used to be a background update that deleted all device_inboxes for deleted
+ # devices.
+ self.db_pool.updates.register_noop_background_update(
+ self.REMOVE_DELETED_DEVICES
)
+        # Used to be a background update that deleted all device_inboxes for hidden
+ # devices.
+ self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
self.db_pool.updates.register_background_update_handler(
- self.REMOVE_HIDDEN_DEVICES,
- self._remove_hidden_devices_from_device_inbox,
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+ self._remove_dead_devices_from_device_inbox,
)
async def _background_drop_index_device_inbox(self, progress, batch_size):
@@ -636,171 +641,83 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return 1
- async def _remove_deleted_devices_from_device_inbox(
- self, progress: JsonDict, batch_size: int
+ async def _remove_dead_devices_from_device_inbox(
+ self,
+ progress: JsonDict,
+ batch_size: int,
) -> int:
- """A background update that deletes all device_inboxes for deleted devices.
-
- This should only need to be run once (when users upgrade to v1.47.0)
+ """A background update to remove devices that were either deleted or hidden from
+ the device_inbox table.
Args:
- progress: JsonDict used to store progress of this background update
- batch_size: the maximum number of rows to retrieve in a single select query
+ progress: The update's progress dict.
+ batch_size: The batch size for this update.
Returns:
- The number of deleted rows
+            The number of stream IDs processed in this batch.
"""
- def _remove_deleted_devices_from_device_inbox_txn(
+ def _remove_dead_devices_from_device_inbox_txn(
txn: LoggingTransaction,
- ) -> int:
- """stream_id is not unique
- we need to use an inclusive `stream_id >= ?` clause,
- since we might not have deleted all dead device messages for the stream_id
- returned from the previous query
+ ) -> Tuple[int, bool]:
- Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
- to avoid problems of deleting a large number of rows all at once
- due to a single device having lots of device messages.
- """
+ if "max_stream_id" in progress:
+ max_stream_id = progress["max_stream_id"]
+ else:
+ txn.execute("SELECT max(stream_id) FROM device_inbox")
+ # There's a type mismatch here between how we want to type the row and
+ # what fetchone says it returns, but we silence it because we know that
+ # res can't be None.
+ res: Tuple[Optional[int]] = txn.fetchone() # type: ignore[assignment]
+ if res[0] is None:
+ # this can only happen if the `device_inbox` table is empty, in which
+ # case we have no work to do.
+ return 0, True
+ else:
+ max_stream_id = res[0]
- last_stream_id = progress.get("stream_id", 0)
+ start = progress.get("stream_id", 0)
+ stop = start + batch_size
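+
+            # Work through the table in fixed windows of `batch_size` stream
+            # IDs; a window may match fewer (or no) rows.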
+ # delete rows in `device_inbox` which do *not* correspond to a known,
+ # unhidden device.
sql = """
- SELECT device_id, user_id, stream_id
- FROM device_inbox
+ DELETE FROM device_inbox
WHERE
- stream_id >= ?
- AND (device_id, user_id) NOT IN (
- SELECT device_id, user_id FROM devices
+ stream_id >= ? AND stream_id < ?
+ AND NOT EXISTS (
+ SELECT * FROM devices d
+ WHERE
+ d.device_id=device_inbox.device_id
+ AND d.user_id=device_inbox.user_id
+ AND NOT hidden
)
- ORDER BY stream_id
- LIMIT ?
- """
-
- txn.execute(sql, (last_stream_id, batch_size))
- rows = txn.fetchall()
+ """
- num_deleted = 0
- for row in rows:
- num_deleted += self.db_pool.simple_delete_txn(
- txn,
- "device_inbox",
- {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
- )
+ txn.execute(sql, (start, stop))
- if rows:
- # send more than stream_id to progress
- # otherwise it can happen in large deployments that
- # no change of status is visible in the log file
- # it may be that the stream_id does not change in several runs
- self.db_pool.updates._background_update_progress_txn(
- txn,
- self.REMOVE_DELETED_DEVICES,
- {
- "device_id": rows[-1][0],
- "user_id": rows[-1][1],
- "stream_id": rows[-1][2],
- },
- )
-
- return num_deleted
-
- number_deleted = await self.db_pool.runInteraction(
- "_remove_deleted_devices_from_device_inbox",
- _remove_deleted_devices_from_device_inbox_txn,
- )
-
- # The task is finished when no more lines are deleted.
- if not number_deleted:
- await self.db_pool.updates._end_background_update(
- self.REMOVE_DELETED_DEVICES
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
+ {
+ "stream_id": stop,
+ "max_stream_id": max_stream_id,
+ },
)
- return number_deleted
-
- async def _remove_hidden_devices_from_device_inbox(
- self, progress: JsonDict, batch_size: int
- ) -> int:
- """A background update that deletes all device_inboxes for hidden devices.
-
- This should only need to be run once (when users upgrade to v1.47.0)
-
- Args:
- progress: JsonDict used to store progress of this background update
- batch_size: the maximum number of rows to retrieve in a single select query
-
- Returns:
- The number of deleted rows
- """
-
- def _remove_hidden_devices_from_device_inbox_txn(
- txn: LoggingTransaction,
- ) -> int:
- """stream_id is not unique
- we need to use an inclusive `stream_id >= ?` clause,
- since we might not have deleted all hidden device messages for the stream_id
- returned from the previous query
-
- Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
- to avoid problems of deleting a large number of rows all at once
- due to a single device having lots of device messages.
- """
-
- last_stream_id = progress.get("stream_id", 0)
-
- sql = """
- SELECT device_id, user_id, stream_id
- FROM device_inbox
- WHERE
- stream_id >= ?
- AND (device_id, user_id) IN (
- SELECT device_id, user_id FROM devices WHERE hidden = ?
- )
- ORDER BY stream_id
- LIMIT ?
- """
-
- txn.execute(sql, (last_stream_id, True, batch_size))
- rows = txn.fetchall()
-
- num_deleted = 0
- for row in rows:
- num_deleted += self.db_pool.simple_delete_txn(
- txn,
- "device_inbox",
- {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
- )
-
- if rows:
- # We don't just save the `stream_id` in progress as
- # otherwise it can happen in large deployments that
- # no change of status is visible in the log file, as
- # it may be that the stream_id does not change in several runs
- self.db_pool.updates._background_update_progress_txn(
- txn,
- self.REMOVE_HIDDEN_DEVICES,
- {
- "device_id": rows[-1][0],
- "user_id": rows[-1][1],
- "stream_id": rows[-1][2],
- },
- )
-
- return num_deleted
+ return stop > max_stream_id
- number_deleted = await self.db_pool.runInteraction(
- "_remove_hidden_devices_from_device_inbox",
- _remove_hidden_devices_from_device_inbox_txn,
+ finished = await self.db_pool.runInteraction(
+ "_remove_devices_from_device_inbox_txn",
+ _remove_dead_devices_from_device_inbox_txn,
)
- # The task is finished when no more lines are deleted.
- if not number_deleted:
+ if finished:
await self.db_pool.updates._end_background_update(
- self.REMOVE_HIDDEN_DEVICES
+ self.REMOVE_DEAD_DEVICES_FROM_INBOX,
)
- return number_deleted
+ return batch_size
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a95ac34f09..b06c1dc45b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
+ await self.db_pool.runInteraction(
+ "set_e2e_fallback_keys_txn",
+ self._set_e2e_fallback_keys_txn,
+ user_id,
+ device_id,
+ fallback_keys,
+ )
+
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
+ )
+
+ def _set_e2e_fallback_keys_txn(
+ self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
- await self.db_pool.simple_upsert(
- "e2e_fallback_keys_json",
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
- values={
- "key_id": key_id,
- "key_json": json_encoder.encode(fallback_key),
- "used": False,
- },
- desc="set_e2e_fallback_key",
+ retcol="key_json",
+ allow_none=True,
)
- await self.invalidate_cache_and_stream(
- "get_e2e_unused_fallback_key_types", (user_id, device_id)
- )
+ new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
+
+ # If the uploaded key is the same as the current fallback key,
+ # don't do anything. This prevents marking the key as unused if it
+ # was already used.
+ if old_key_json != new_key_json:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ )
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
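
Why the compare-before-upsert above matters, as a sketched timeline
(hypothetical sequence of events; `F` stands for any fallback key):

    # 1. A device uploads fallback key F   -> row stored with used = False.
    # 2. Another device claims F           -> row marked used = True.
    # 3. The device re-uploads the same F  -> old_key_json == new_key_json, so
    #    the upsert is skipped and `used` stays True. Previously the blind
    #    upsert reset `used` to False, re-advertising an already-claimed key.
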
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 120e4807d1..c3440de2cb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1,6 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
)
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -108,16 +106,21 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- # Ideally we'd move these ID gens here, unfortunately some other ID
- # generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
- self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
+ # Since we have been configured to write, we ought to have id generators,
+ # rather than id trackers.
+ assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+ assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1553,11 +1556,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+ to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.set(
+ (cache_entry.event.event_id,), cache_entry
+ )
txn.call_after(prefill)
@@ -1696,34 +1701,33 @@ class PersistEventsStore:
},
)
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
+ def _handle_event_relations(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
+ """Handles inserting relation data during persistence of events
Args:
- txn
- event (EventBase)
+ txn: The current database transaction.
+ event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
if not relation:
# No relations
return
+ # Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- RelationTypes.THREAD,
- ):
- # Unknown relation type
+ if not isinstance(rel_type, str):
return
parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
+ if not isinstance(parent_id, str):
return
- aggregation_key = relation.get("key")
+ # Annotations have a key field.
+ aggregation_key = None
+ if rel_type == RelationTypes.ANNOTATION:
+ aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn(
txn,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index ae3a8a63e4..c88fd35e7f 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+ # The event_thread_relation background update was replaced with the
+ # event_arbitrary_relations one, which handles any relation to avoid
+        # needing to potentially crawl the entire events table in the future.
+ self.db_pool.updates.register_noop_background_update("event_thread_relation")
+
self.db_pool.updates.register_background_update_handler(
- "event_thread_relation", self._event_thread_relation
+ "event_arbitrary_relations",
+ self._event_arbitrary_relations,
)
################################################################################
@@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
- async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
- """Background update handler which will store thread relations for existing events."""
+ async def _event_arbitrary_relations(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update handler which will store previously unknown relations for existing events."""
last_event_id = progress.get("last_event_id", "")
- def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
+ # Fetch events and then filter based on whether the event has a
+ # relation or not.
txn.execute(
"""
SELECT event_id, json FROM event_json
- LEFT JOIN event_relations USING (event_id)
- WHERE event_id > ? AND event_relations.event_id IS NULL
+ WHERE event_id > ?
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
- missing_thread_relations = []
+ # (event_id, parent_id, rel_type) for each relation
+ relations_to_insert: List[Tuple[str, str, str]] = []
for (event_id, event_json_raw) in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
continue
- # If there's no relation (or it is not a thread), skip!
+ # If there's no relation, skip!
relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict):
continue
- if relates_to.get("rel_type") != RelationTypes.THREAD:
+
+ # If the relation type or parent event ID is not a string, skip it.
+ #
+ # Do not consider relation types that have existed for a long time,
+ # since they will already be listed in the `event_relations` table.
+ rel_type = relates_to.get("rel_type")
+ if not isinstance(rel_type, str) or rel_type in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
continue
- # Get the parent ID.
parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str):
continue
- missing_thread_relations.append((event_id, parent_id))
+ relations_to_insert.append((event_id, parent_id, rel_type))
+
+ # Insert the missing data, note that we upsert here in case the event
+ # has already been processed.
+ if relations_to_insert:
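+            # Keyed on `event_id` alone: an event carries at most one
+            # `m.relates_to`, so re-processing simply rewrites the same row.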
+ self.db_pool.simple_upsert_many_txn(
+ txn=txn,
+ table="event_relations",
+ key_names=("event_id",),
+ key_values=[(r[0],) for r in relations_to_insert],
+ value_names=("relates_to_id", "relation_type"),
+ value_values=[r[1:] for r in relations_to_insert],
+ )
- # Insert the missing data.
- self.db_pool.simple_insert_many_txn(
- txn=txn,
- table="event_relations",
- values=[
- {
- "event_id": event_id,
- "relates_to_Id": parent_id,
- "relation_type": RelationTypes.THREAD,
- }
- for event_id, parent_id in missing_thread_relations
- ],
- )
+ # Iterate the parent IDs and invalidate caches.
+ for parent_id in {r[1] for r in relations_to_insert}:
+ cache_tuple = (parent_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.get_relations_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_aggregation_groups_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_thread_summary, cache_tuple
+ )
if results:
latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn(
- txn, "event_thread_relation", {"last_event_id": latest_event_id}
+ txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
)
return len(results)
num_rows = await self.db_pool.runInteraction(
- desc="event_thread_relation", func=_event_thread_relation_txn
+ desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
)
if not num_rows:
- await self.db_pool.updates._end_background_update("event_thread_relation")
+ await self.db_pool.updates._end_background_update(
+ "event_arbitrary_relations"
+ )
return num_rows
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..4cefc0a07e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
+ TYPE_CHECKING,
+ Any,
Collection,
Container,
Dict,
Iterable,
List,
+ NoReturn,
Optional,
Set,
Tuple,
+ cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `_enqueue_events` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[int]
+ room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
+ self._stream_id_gen: AbstractStreamIdTracker
+ self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache = LruCache(
+ self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
+ self._event_fetch_list: List[
+ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+ ] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
+ def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
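+        # Advance the relevant stream ID tracker so it reflects the new token.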
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[False] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[False] = ...,
+ check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[True] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[True] = ...,
+ check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Iterable[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, _EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred())
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id):
+ def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn: Connection) -> None:
+ def _maybe_start_fetch_thread(self) -> None:
+ """Starts an event fetch thread if we are not yet at the maximum number."""
+ with self._event_fetch_lock:
+ if (
+ self._event_fetch_list
+ and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+ ):
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process("fetch_events", self._fetch_thread)
+
+ async def _fetch_thread(self) -> None:
+ """Services requests for events from `_event_fetch_list`."""
+ exc = None
+ try:
+ await self.db_pool.runWithConnection(self._fetch_loop)
+ except BaseException as e:
+ exc = e
+ raise
+ finally:
+ should_restart = False
+ event_fetches_to_fail = []
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+ # There may still be work remaining in `_event_fetch_list` if we
+ # failed, or it was added in between us deciding to exit and
+ # decrementing `_event_fetch_ongoing`.
+ if self._event_fetch_list:
+ if exc is None:
+ # We decided to exit, but then some more work was added
+ # before `_event_fetch_ongoing` was decremented.
+ # If a new event fetch thread was not started, we should
+ # restart ourselves since the remaining event fetch threads
+ # may take a while to get around to the new work.
+ #
+ # Unfortunately it is not possible to tell whether a new
+ # event fetch thread was started, so we restart
+ # unconditionally. If we are unlucky, we will end up with
+ # an idle fetch thread, but it will time out after
+ # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+ # in any case.
+ #
+ # Note that multiple fetch threads may run down this path at
+ # the same time.
+ should_restart = True
+ elif isinstance(exc, Exception):
+ if self._event_fetch_ongoing == 0:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will
+ # handle them.
+ event_fetches_to_fail = self._event_fetch_list
+ self._event_fetch_list = []
+ else:
+ # We weren't the last remaining fetcher, so another
+ # fetcher will pick up the work. This will either happen
+ # after their existing work, however long that takes,
+ # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+ # they are idle.
+ pass
+ else:
+ # The exception is a `SystemExit`, `KeyboardInterrupt` or
+ # `GeneratorExit`. Don't try to do anything clever here.
+ pass
+
+ if should_restart:
+ # We exited cleanly but noticed more work.
+ self._maybe_start_fetch_thread()
+
+ if event_fetches_to_fail:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will handle them.
+ assert exc is not None
+ with PreserveLoggingContext():
+ for _, deferred in event_fetches_to_fail:
+ deferred.errback(exc)
+
+ def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- try:
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- break
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ # There are no requests waiting. If we haven't yet reached the
+ # maximum iteration limit, wait for some more requests to turn up.
+ # Otherwise, bail out.
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ return
+
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
- finally:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ self._fetch_event_list(conn, event_list)
def _fetch_event_list(
- self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ self,
+ conn: LoggingDatabaseConnection,
+ event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire():
+ def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
+ def fire_errback(exc: Exception) -> None:
+ for _, d in event_list:
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
+ self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
- self, event_ids: Iterable[str]
- ) -> Dict[str, _EventCacheEntry]:
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_events = {}
+ fetched_event_ids: Set[str] = set()
+ fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids = set()
+ redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_events[event_id] = row
+ fetched_event_ids.add(event_id)
if row:
+ fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map = {}
+ event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
- if not row:
- continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
+ room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map = {}
+ result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = _EventCacheEntry(
+ cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d = defer.Deferred()
+ events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
-
self._event_fetch_lock.notify()
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.db_pool.runWithConnection, self._do_fetch
- )
+ self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- set[str]: The events we have already seen.
+ The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+ def have_seen_events_txn(
+ txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+ ) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str):
+ async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(self, txn, room_id):
+ def _get_current_state_event_counts_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- dict[str:int] of complexity version to complexity.
+ dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self):
+ def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1295,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(txn):
+ def get_all_new_forward_event_rows(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1332,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(txn):
+ def get_ex_outlier_stream_rows_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1386,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(txn):
+ def get_all_new_backfill_event_rows(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates = [(row[0], row[1:]) for row in txn]
+ new_event_updates: List[
+ Tuple[int, Tuple[str, str, str, str, str, str]]
+ ] = []
+ row: Tuple[int, str, str, str, str, str, str]
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- new_event_updates.extend((row[0], row[1:]) for row in txn)
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(txn):
+ def get_all_updated_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
- def get_deltas_for_stream_id_txn(txn, stream_id):
+ def get_deltas_for_stream_id_txn(
+ txn: LoggingTransaction, stream_id: int
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple] = await self.db_pool.runInteraction(
+ rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id):
+ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(txn):
+ def get_next_event_to_expire_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return txn.fetchone()
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self):
+ async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn):
+ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
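
The cast()/type-ignore pattern above recurs throughout this file: DB-API cursors yield Tuple[Any, ...] of arbitrary length, so fixed-length tuple types have to be asserted rather than inferred. A minimal, self-contained sketch of both workarounds (FakeCursor is a stand-in for LoggingTransaction, not a real Synapse class):

from typing import Any, Iterator, List, Tuple, cast

class FakeCursor:
    # A stand-in for a DB-API cursor: iterating yields variadic tuples,
    # i.e. Tuple[Any, ...] of arbitrary length.
    def __iter__(self) -> Iterator[Tuple[Any, ...]]:
        yield (1, "$event", "!room")

    def fetchall(self) -> List[Tuple[Any, ...]]:
        return [(1, "$event", "!room")]

txn = FakeCursor()

# Workaround 1: cast the whole result set to the fixed-length tuple type.
rows = cast(List[Tuple[int, str, str]], txn.fetchall())

# Workaround 2: declare the loop variable's fixed-length type up front and
# silence the variadic-to-fixed assignment error on the loop itself.
row: Tuple[int, str, str]
for row in txn:  # type: ignore[assignment]
    print(row[0], row[1:])
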
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: Union[
- StreamIdGenerator, SlavedIdTracker
- ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ )
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 5e55440570..e1ddf06916 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -84,28 +84,37 @@ class TokenLookupResult:
return self.user_id
-@attr.s(frozen=True, slots=True)
+@attr.s(auto_attribs=True, frozen=True, slots=True)
class RefreshTokenLookupResult:
"""Result of looking up a refresh token."""
- user_id = attr.ib(type=str)
+ user_id: str
"""The user this token belongs to."""
- device_id = attr.ib(type=str)
+ device_id: str
"""The device associated with this refresh token."""
- token_id = attr.ib(type=int)
+ token_id: int
"""The ID of this refresh token."""
- next_token_id = attr.ib(type=Optional[int])
+ next_token_id: Optional[int]
"""The ID of the refresh token which replaced this one."""
- has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+ has_next_refresh_token_been_refreshed: bool
"""True if the next refresh token was used for another refresh."""
- has_next_access_token_been_used = attr.ib(type=bool)
+ has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
+ expiry_ts: Optional[int]
+ """The time at which the refresh token expires and can not be used.
+ If None, the refresh token doesn't expire."""
+
+ ultimate_session_expiry_ts: Optional[int]
+ """The time at which the session comes to an end and can no longer be
+ refreshed.
+ If None, the session can be refreshed indefinitely."""
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
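
The auto_attribs=True conversion trades explicit attr.ib(type=...) declarations for plain PEP 526 annotations; the resulting classes behave identically. A minimal before/after sketch:

from typing import Optional
import attr

# Old style: explicit attr.ib() with a type argument.
@attr.s(frozen=True, slots=True)
class OldStyle:
    user_id = attr.ib(type=str)
    next_token_id = attr.ib(type=Optional[int])

# New style: auto_attribs picks up the annotations directly.
@attr.s(auto_attribs=True, frozen=True, slots=True)
class NewStyle:
    user_id: str
    next_token_id: Optional[int]

# Both produce frozen, slotted classes with the same generated methods.
assert OldStyle("@a:b", None) == OldStyle("@a:b", None)
assert NewStyle("@a:b", None) == NewStyle("@a:b", None)
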
@@ -1198,8 +1207,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
+ assert self._account_validity_startup_job_max_delta is not None
expiration_ts = random.randrange(
- expiration_ts - self._account_validity_startup_job_max_delta,
+ int(expiration_ts - self._account_validity_startup_job_max_delta),
expiration_ts,
)
@@ -1625,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
rt.user_id,
rt.device_id,
rt.next_token_id,
- (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
- at.used has_next_access_token_been_used
+ (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
+ at.used AS has_next_access_token_been_used,
+ rt.expiry_ts,
+ rt.ultimate_session_expiry_ts
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1647,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
has_next_refresh_token_been_refreshed=row[4],
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
+ expiry_ts=row[6],
+ ultimate_session_expiry_ts=row[7],
)
return await self.db_pool.runInteraction(
@@ -1728,11 +1742,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
)
self.db_pool.updates.register_background_update_handler(
- "user_threepids_grandfather", self._bg_user_threepids_grandfather
+ "users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- self.db_pool.updates.register_background_update_handler(
- "users_set_deactivated_flag", self._background_update_set_deactivated_flag
+ self.db_pool.updates.register_noop_background_update(
+ "user_threepids_grandfather"
)
self.db_pool.updates.register_background_index_update(
@@ -1805,35 +1819,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return nb_processed
- async def _bg_user_threepids_grandfather(self, progress, batch_size):
- """We now track which identity servers a user binds their 3PID to, so
- we need to handle the case of existing bindings where we didn't track
- this.
-
- We do this by grandfathering in existing user threepids assuming that
- they used one of the server configured trusted identity servers.
- """
- id_servers = set(self.config.registration.trusted_third_party_id_servers)
-
- def _bg_user_threepids_grandfather_txn(txn):
- sql = """
- INSERT INTO user_threepid_id_server
- (user_id, medium, address, id_server)
- SELECT user_id, medium, address, ?
- FROM user_threepids
- """
-
- txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
-
- if id_servers:
- await self.db_pool.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
- )
-
- await self.db_pool.updates._end_background_update("user_threepids_grandfather")
-
- return 1
-
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
@@ -1943,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: str,
token: str,
device_id: Optional[str],
+ expiry_ts: Optional[int],
+ ultimate_session_expiry_ts: Optional[int],
) -> int:
"""Adds a refresh token for the given user.
@@ -1950,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the refresh token.
+ expiry_ts (milliseconds since the epoch): Time after which the
+ refresh token cannot be used.
+ If None, the refresh token does not expire on its own (though it
+ still becomes invalid once it has been used).
+ ultimate_session_expiry_ts (milliseconds since the epoch):
+ Time at which the session will end and cannot be extended any
+ further.
+ If None, the session can be refreshed indefinitely.
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1965,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"token": token,
"next_token_id": None,
+ "expiry_ts": expiry_ts,
+ "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
},
desc="add_refresh_token_to_user",
)
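
One way a caller might derive the two new timestamps when minting a refresh token. This is a sketch: the parameter names are illustrative and do not necessarily match Synapse's config options.

import time
from typing import Optional, Tuple

def refresh_token_expiries(
    now_ms: int,
    refresh_token_lifetime_ms: Optional[int],
    session_lifetime_ms: Optional[int],
    session_start_ms: int,
) -> Tuple[Optional[int], Optional[int]]:
    # expiry_ts: the token itself stops working this long after issuance;
    # None means the token never expires on its own.
    expiry_ts = (
        now_ms + refresh_token_lifetime_ms
        if refresh_token_lifetime_ms is not None
        else None
    )
    # ultimate_session_expiry_ts: a hard cap measured from the start of
    # the session; refreshing cannot extend the session past it.
    ultimate_session_expiry_ts = (
        session_start_ms + session_lifetime_ms
        if session_lifetime_ms is not None
        else None
    )
    return expiry_ts, ultimate_session_expiry_ts

now = int(time.time() * 1000)
print(refresh_token_expiries(now, 60_000, 86_400_000, now))
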
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 907af10995..0a43acda07 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
+ async def event_includes_relation(self, event_id: str) -> bool:
+ """Check if the given event relates to another event.
+
+ An event has a relation if it has a valid m.relates_to with a rel_type
+ and event_id in the content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$other_event_id"
+ }
+ }
+ }
+
+ Args:
+ event_id: The event to check.
+
+ Returns:
+ True if the event includes a valid relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"event_id": event_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_includes_relation",
+ )
+ return result is not None
+
+ async def event_is_target_of_relation(self, parent_id: str) -> bool:
+ """Check if the given event is the target of another event's relation.
+
+ An event is the target of an event relation if another event has a
+ valid m.relates_to with a rel_type and an event_id pointing to
+ parent_id in its content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$parent_id"
+ }
+ }
+ }
+
+ Args:
+ parent_id: The event to check.
+
+ Returns:
+ True if the event is the target of another event's relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"relates_to_id": parent_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_is_target_of_relation",
+ )
+ return result is not None
+
@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
@@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_event_has_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
@@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore):
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_if_event_has_relations", _get_if_event_has_relations
+ "get_if_events_have_relations", _get_if_events_have_relations
)
async def has_user_annotated_event(
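
A hedged usage sketch for the two new helpers; the wrapper function and its policy are illustrative, not part of Synapse:

# Illustrative only: `store` stands in for a RelationsWorkerStore instance.
async def is_relation_free(store, event_id: str) -> bool:
    # An event that neither relates to anything nor is the target of a
    # relation can be handled without worrying about aggregations.
    has_relation = await store.event_includes_relation(event_id)
    is_target = await store.event_is_target_of_relation(event_id)
    return not (has_relation or is_target)
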
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 17b398bb69..7d694d852d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -397,6 +397,20 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
+ """
+ Returns the user who has blocked the room, if any.
+
+ The user_id column is non-nullable, so None is returned only when
+ the room is not blocked.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="room_is_blocked_by",
+ )
+
async def get_rooms_paginate(
self,
start: int,
@@ -1775,3 +1789,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.is_room_blocked,
(room_id,),
)
+
+ async def unblock_room(self, room_id: str) -> None:
+ """Remove the room from blocking list.
+
+ Args:
+ room_id: Room to unblock
+ """
+ await self.db_pool.simple_delete(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ desc="unblock_room",
+ )
+ await self.db_pool.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
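
How the new helpers might compose in admin code; the wrapper below is illustrative, only room_is_blocked_by and unblock_room are real:

# Illustrative only: `store` stands in for a RoomStore instance.
async def ensure_room_unblocked(store, room_id: str) -> None:
    blocked_by = await store.room_is_blocked_by(room_id)
    if blocked_by is not None:
        # unblock_room deletes the blocked_rooms row and invalidates the
        # is_room_blocked cache across workers via the replication stream.
        await store.unblock_room(room_id)
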
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 8b9c6adae2..e45adfcb55 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -131,24 +131,16 @@ def prepare_database(
"config==None in prepare_database, but database is not empty"
)
- # if it's a worker app, refuse to upgrade the database, to avoid multiple
- # workers doing it at once.
- if config.worker.worker_app is None:
- _upgrade_existing_database(
- cur,
- version_info,
- database_engine,
- config,
- databases=databases,
- )
- elif version_info.current_version < SCHEMA_VERSION:
- # If the DB is on an older version than we expect then we refuse
- # to start the worker (as the main process needs to run first to
- # update the schema).
- raise UpgradeDatabaseException(
- OUTDATED_SCHEMA_ON_WORKER_ERROR
- % (SCHEMA_VERSION, version_info.current_version)
- )
+ # This should be run on all processes, master or worker. The master will
+ # apply the deltas, while workers will check if any outstanding deltas
+ # exist and raise a PrepareDatabaseException if they do.
+ _upgrade_existing_database(
+ cur,
+ version_info,
+ database_engine,
+ config,
+ databases=databases,
+ )
else:
logger.info("%r: Initialising new database", databases)
@@ -358,6 +350,18 @@ def _upgrade_existing_database(
is_worker = config and config.worker.worker_app is not None
+ # If the schema version needs to be updated, and we are on a worker, we immediately
+ # know to bail out as workers cannot update the database schema. Only one
+ # process may update the database at a time, so we delegate this task to the master.
+ if is_worker and current_schema_state.current_version < SCHEMA_VERSION:
+ # If the DB is on an older version than we expect then we refuse
+ # to start the worker (as the main process needs to run first to
+ # update the schema).
+ raise UpgradeDatabaseException(
+ OUTDATED_SCHEMA_ON_WORKER_ERROR
+ % (SCHEMA_VERSION, current_schema_state.current_version)
+ )
+
if (
current_schema_state.compat_version is not None
and current_schema_state.compat_version > SCHEMA_VERSION
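
The combined control flow of the two hunks above, reduced to a sketch (RuntimeError stands in for UpgradeDatabaseException, and the real function does much more):

def upgrade_or_bail(is_worker: bool, current_version: int, schema_version: int) -> None:
    # All processes enter the upgrade path, but a worker refuses to run if
    # there are outstanding deltas: only the master may mutate the schema.
    if is_worker and current_version < schema_version:
        raise RuntimeError(
            "Outdated schema v%d (expected v%d); start the main process first"
            % (current_version, schema_version)
        )
    # ... the master (or an already up-to-date worker) proceeds to apply deltas ...
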
diff --git a/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..82f6408b36
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/06remove_deleted_devices_from_device_inbox.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- when a device was deleted using Synapse earlier than 1.47.0.
+-- This runs as a background task, but may take a while to finish.
+
+-- Remove any existing instances of this job running. It's OK to stop and restart this job,
+-- as it's just deleting entries from a table - no progress will be lost.
+--
+-- This is necessary because a similar migration running the job was
+-- accidentally included in schema version 64 during v1.47.0rc1 and rc2. If a
+-- homeserver had updated from Synapse <=v1.45.0 (schema version <=64),
+-- then they would have started running this background update already.
+-- If that update was still running, then simply inserting it again would
+-- cause an SQL failure. So we effectively do an "upsert" here instead.
+
+DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox';
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6506, 'remove_deleted_devices_from_device_inbox', '{}');
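
The delete-then-insert above is what makes re-running the delta safe without depending on ON CONFLICT support. The same idiom, sketched with Python's sqlite3 so it can be run standalone (table layout simplified):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE background_updates "
    "(ordering INT, update_name TEXT, progress_json TEXT)"
)

def register_idempotently(ordering: int, name: str) -> None:
    # Delete any pre-existing registration first, so re-running the
    # delta (or racing an older delta) cannot hit a duplicate insert.
    with conn:  # one transaction: both statements commit together
        conn.execute(
            "DELETE FROM background_updates WHERE update_name = ?", (name,)
        )
        conn.execute(
            "INSERT INTO background_updates (ordering, update_name, progress_json)"
            " VALUES (?, ?, '{}')",
            (ordering, name),
        )

register_idempotently(6506, "remove_deleted_devices_from_device_inbox")
register_idempotently(6506, "remove_deleted_devices_from_device_inbox")  # safe
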
diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
index d60517f7b4..267b2cb539 100644
--- a/synapse/storage/schema/main/delta/65/02_thread_relations.sql
+++ b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
@@ -15,4 +15,4 @@
-- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
- (6502, 'event_thread_relation', '{}');
+ (6507, 'event_arbitrary_relations', '{}');
diff --git a/synapse/storage/schema/main/delta/65/05remove_deleted_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql
index 076179123d..d79455c2ce 100644
--- a/synapse/storage/schema/main/delta/65/05remove_deleted_devices_from_device_inbox.sql
+++ b/synapse/storage/schema/main/delta/65/08_device_inbox_background_updates.sql
@@ -13,10 +13,6 @@
* limitations under the License.
*/
-
--- Remove messages from the device_inbox table which were orphaned
--- when a device was deleted using Synapse earlier than 1.47.0.
--- This runs as background task, but may take a bit to finish.
-
+-- Background update to clear the inboxes of hidden and deleted devices.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
- (6505, 'remove_deleted_devices_from_device_inbox', '{}');
+ (6508, 'remove_dead_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
new file mode 100644
index 0000000000..bdc491c817
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/10_expirable_refresh_tokens.sql
@@ -0,0 +1,28 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE refresh_tokens
+ -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
+ -- They may not be used after they have expired.
+ -- If null, then the refresh token's lifetime is unlimited.
+ ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
+
+ALTER TABLE refresh_tokens
+ -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
+ -- No matter how much the access and refresh tokens are refreshed, they cannot
+ -- be extended past this time.
+ -- If null, then the session length is unlimited.
+ ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
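
What consuming the two columns could look like at refresh time; a sketch under the assumption that either limit being passed rejects the refresh:

from typing import Optional

def refresh_allowed(
    now_ms: int,
    expiry_ts: Optional[int],
    ultimate_session_expiry_ts: Optional[int],
) -> bool:
    # A NULL column means "no limit" for that particular check.
    if expiry_ts is not None and now_ms >= expiry_ts:
        return False  # the refresh token itself has expired
    if (
        ultimate_session_expiry_ts is not None
        and now_ms >= ultimate_session_expiry_ts
    ):
        return False  # the whole session has reached its hard cap
    return True
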
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def get_next(self) -> AsyncContextManager[int]:
- raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+ """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+ Stream IDs are monotonically increasing or decreasing integers representing write
+ transactions. The "current" stream ID is the stream ID such that all transactions
+ with equal or smaller stream IDs have completed. Since transactions may complete out
+ of order, this is not the same as the stream ID of the last completed transaction.
+
+ Completed transactions include both committed transactions and transactions that
+ have been rolled back.
+ """
@abc.abstractmethod
- def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ def advance(self, instance_name: str, new_id: int) -> None:
+ """Advance the position of the named writer to the given ID, if greater
+ than the existing entry.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+
+ Returns:
+ The maximum stream id.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to `get_current_token`.
+ """
+ raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+ """Generates stream IDs for a stream that may have multiple writers.
+
+ Each stream ID represents a write transaction, whose completion is tracked
+ so that the "current" stream ID of the stream can be determined.
+
+ See `AbstractStreamIdTracker` for more details.
+ """
+
+ @abc.abstractmethod
+ def get_next(self) -> AsyncContextManager[int]:
+ """
+ Usage:
+ async with stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ """
+ Usage:
+ async with stream_id_gen.get_next_mult(n) as stream_ids:
+ # ... persist events ...
+ """
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
- """Used to generate new stream ids when persisting events while keeping
- track of which transactions have been completed.
+ """Generates and tracks stream IDs for a stream with a single writer.
- This allows us to get the "current" stream id, i.e. the stream id such that
- all ids less than or equal to it have completed. This handles the fact that
- persistence of events can complete out of order.
+ This class must only be used when the current Synapse process is the sole
+ writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ def advance(self, instance_name: str, new_id: int) -> None:
+ # `StreamIdGenerator` should only be used when there is a single writer,
+ # so replication should never happen.
+ raise Exception("Replication is not supported by StreamIdGenerator")
+
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
with self._lock:
self._current += self._step
next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
- """
- Usage:
- async with stream_id_gen.get_next(n) as stream_ids:
- # ... persist events ...
- """
with self._lock:
next_ids = range(
self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
-
- Returns:
- The maximum stream id.
- """
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
- """An ID generator that tracks a stream that can have multiple writers.
+ """Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
- """
- Usage:
- async with stream_id_gen.get_next_mult(5) as stream_ids:
- # ... persist events ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
-
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer."""
-
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
- """Advance the position of the named writer to the given ID, if greater
- than existing entry.
- """
-
new_id *= self._return_factor
with self._lock:
|