diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index be67d1ff22..a85633efcd 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -70,6 +70,7 @@ from .state import StateStore
from .stats import StatsStore
from .stream import StreamWorkerStore
from .tags import TagsStore
+from .task_scheduler import TaskSchedulerWorkerStore
from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
@@ -127,6 +128,7 @@ class DataStore(
CacheInvalidationWorkerStore,
LockStore,
SessionStore,
+ TaskSchedulerWorkerStore,
):
def __init__(
self,
@@ -168,6 +170,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
+ admins: Optional[bool] = None,
order_by: str = UserSortOrder.NAME.value,
direction: Direction = Direction.FORWARDS,
approved: bool = True,
@@ -184,6 +187,9 @@ class DataStore(
name: search for local part of user_id or display name
             guests: whether to include guest users
deactivated: whether to include deactivated users
+ admins: Optional flag to filter admins. If true, only admins are queried.
+                If false, admins are excluded from the query. When it is
+                None (the default), both admins and non-admins are queried.
order_by: the sort order of the returned list
direction: sort ascending or descending
approved: whether to include approved users
@@ -220,6 +226,12 @@ class DataStore(
if not deactivated:
filters.append("deactivated = 0")
+ if admins is not None:
+ if admins:
+ filters.append("admin = 1")
+ else:
+ filters.append("admin = 0")
+
if not approved:
# We ignore NULL values for the approved flag because these should only
# be already existing users that we consider as already approved.
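The `admins` flag above is a tri-state filter: `True` selects only admins, `False` excludes them, and `None` leaves the admin column out of the query entirely. A minimal standalone sketch of how the clause composition behaves (simplified to the columns visible in this hunk; not the actual `get_users_paginate` implementation):

```python
from typing import List, Optional


def build_admin_filters(deactivated: bool = False, admins: Optional[bool] = None) -> str:
    """Illustrative sketch of the tri-state `admins` filter."""
    filters: List[str] = []

    if not deactivated:
        filters.append("deactivated = 0")

    # None (the default) applies no filtering on the admin column at all.
    if admins is not None:
        filters.append("admin = 1" if admins else "admin = 0")

    return " AND ".join(filters)


print(build_admin_filters(admins=True))   # deactivated = 0 AND admin = 1
print(build_admin_filters(admins=False))  # deactivated = 0 AND admin = 0
print(build_admin_filters())              # deactivated = 0
```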
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2fbd389c71..18905e07b6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -584,6 +584,19 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
else:
return 0
+ async def stream_introspection_token_invalidation(
+ self, key: Tuple[Optional[str]]
+ ) -> None:
+ """
+ Stream an invalidation request for the introspection token cache to workers
+
+ Args:
+ key: token_id of the introspection token to remove from the cache
+ """
+ await self.send_invalidation_to_replication(
+ "introspection_token_invalidation", key
+ )
+
@wrap_as_background_process("clean_up_old_cache_invalidations")
async def _clean_up_cache_invalidation_wrapper(self) -> None:
"""
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e4162f846b..fa69a4a298 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -33,6 +33,7 @@ from typing_extensions import Literal
from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
+from synapse.config.homeserver import HomeServerConfig
from synapse.logging.opentracing import (
get_active_span_text_map,
set_tag,
@@ -1663,6 +1664,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache: LruCache[
Tuple[str, str], Literal[True]
] = LruCache(cache_name="device_id_exists", max_size=10000)
+ self.config: HomeServerConfig = hs.config
async def store_device(
self,
@@ -1784,6 +1786,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
+ # TODO: don't nuke the entire cache once there is a way to associate
+ # device_id -> introspection_token
+ if self.config.experimental.msc3861.enabled:
+ # mypy ignore - the token cache is defined on MSC3861DelegatedAuth
+ self.auth._token_cache.invalidate_all() # type: ignore[attr-defined]
+ await self.stream_introspection_token_invalidation((None,))
+
async def update_device(
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None:
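The TODO above points at a finer-grained design: once device IDs can be associated with introspection tokens, only the affected tokens would need invalidating. A purely hypothetical sketch of such an association (not part of this change; names invented for illustration):

```python
from typing import Dict, Iterable, Set

# Hypothetical device_id -> token_id index that would allow targeted
# invalidation instead of clearing the whole cache on device deletion.
_device_to_tokens: Dict[str, Set[str]] = {}


def record_token_for_device(device_id: str, token_id: str) -> None:
    _device_to_tokens.setdefault(device_id, set()).add(token_id)


def tokens_to_invalidate(deleted_device_ids: Iterable[str]) -> Set[str]:
    """Collect only the token IDs tied to the deleted devices."""
    out: Set[str] = set()
    for device_id in deleted_device_ids:
        out.update(_device_to_tokens.pop(device_id, set()))
    return out
```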
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 534dc32413..fab7008a8f 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -452,33 +452,56 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# sets.
seen_chains: Set[int] = set()
- sql = """
- SELECT event_id, chain_id, sequence_number
- FROM event_auth_chains
- WHERE %s
- """
- for batch in batch_iter(initial_events, 1000):
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", batch
- )
- txn.execute(sql % (clause,), args)
+ # Fetch the chain cover index for the initial set of events we're
+ # considering.
+ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(events_to_fetch, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
- for event_id, chain_id, sequence_number in txn:
- chain_info[event_id] = (chain_id, sequence_number)
- seen_chains.add(chain_id)
- chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+ for event_id, chain_id, sequence_number in txn:
+ chain_info[event_id] = (chain_id, sequence_number)
+ seen_chains.add(chain_id)
+ chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+ fetch_chain_info(initial_events)
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)
+
+ # The result set to return, i.e. the auth chain difference.
+ result: Set[str] = set()
+
if events_missing_chain_info:
- # This can happen due to e.g. downgrade/upgrade of the server. We
- # raise an exception and fall back to the previous algorithm.
- logger.info(
- "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ # For some reason we have events we haven't calculated the chain
+ # index for, so we need to handle those separately. This should only
+ # happen for older rooms where the server doesn't have all the auth
+ # events.
+ result = self._fixup_auth_chain_difference_sets(
+ txn,
room_id,
- events_missing_chain_info,
+ state_sets=state_sets,
+ events_missing_chain_info=events_missing_chain_info,
+ events_that_have_chain_index=chain_info,
)
- raise _NoChainCoverIndex(room_id)
+
+ # We now need to refetch any events that we have added to the state
+ # sets.
+ new_events_to_fetch = {
+ event_id
+ for state_set in state_sets
+ for event_id in state_set
+ if event_id not in initial_events
+ }
+
+ fetch_chain_info(new_events_to_fetch)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
@@ -487,8 +510,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
set_to_chain.append(chains)
- for event_id in state_set:
- chain_id, seq_no = chain_info[event_id]
+ for state_id in state_set:
+ chain_id, seq_no = chain_info[state_id]
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
@@ -532,7 +555,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
- result = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
@@ -588,6 +610,122 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
+ def _fixup_auth_chain_difference_sets(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_sets: List[Set[str]],
+ events_missing_chain_info: Set[str],
+ events_that_have_chain_index: Collection[str],
+ ) -> Set[str]:
+ """Helper for `_get_auth_chain_difference_using_cover_index_txn` to
+ handle the case where we haven't calculated the chain cover index for
+ all events.
+
+ This modifies `state_sets` so that they only include events that have a
+ chain cover index, and returns a set of event IDs that are part of the
+ auth difference.
+ """
+
+ # This works similarly to the handling of unpersisted events in
+        # `synapse.state.v2_get_auth_chain_difference`. We use the observation
+ # that if you can split the set of events into two classes X and Y,
+ # where no events in Y have events in X in their auth chain, then we can
+ # calculate the auth difference by considering X and Y separately.
+ #
+ # We do this in three steps:
+        #   1. Compute the set of events without a chain cover index that
+        #      belong to the auth difference.
+        #   2. Replace the un-indexed events in the state_sets with their auth
+ # events, recursively, until the state_sets contain only indexed
+ # events. We can then calculate the auth difference of those state
+ # sets using the chain cover index.
+ # 3. Add the results of 1 and 2 together.
+
+ # By construction we know that all events that we haven't persisted the
+ # chain cover index for are contained in
+ # `event_auth_chain_to_calculate`, so we pull out the events from those
+ # rather than doing recursive queries to walk the auth chain.
+ #
+ # We pull out those events with their auth events, which gives us enough
+ # information to construct the auth chain of an event up to auth events
+ # that have the chain cover index.
+ sql = """
+ SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL
+ FROM event_auth_chain_to_calculate AS tc
+ LEFT JOIN event_auth AS ea USING (event_id)
+ LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id)
+ WHERE tc.room_id = ?
+ """
+ txn.execute(sql, (room_id,))
+ event_to_auth_ids: Dict[str, Set[str]] = {}
+ events_that_have_chain_index = set(events_that_have_chain_index)
+ for event_id, auth_id, auth_id_has_chain in txn:
+ s = event_to_auth_ids.setdefault(event_id, set())
+ if auth_id is not None:
+ s.add(auth_id)
+ if auth_id_has_chain:
+ events_that_have_chain_index.add(auth_id)
+
+ if events_missing_chain_info - event_to_auth_ids.keys():
+ # Uh oh, we somehow haven't correctly done the chain cover index,
+ # bail and fall back to the old method.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info - event_to_auth_ids.keys(),
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Create a map from event IDs we care about to their partial auth chain.
+ event_id_to_partial_auth_chain: Dict[str, Set[str]] = {}
+ for event_id, auth_ids in event_to_auth_ids.items():
+ if not any(event_id in state_set for state_set in state_sets):
+ continue
+
+ processing = set(auth_ids)
+ to_add = set()
+ while processing:
+ auth_id = processing.pop()
+ to_add.add(auth_id)
+
+ sub_auth_ids = event_to_auth_ids.get(auth_id)
+ if sub_auth_ids is None:
+ continue
+
+ processing.update(sub_auth_ids - to_add)
+
+ event_id_to_partial_auth_chain[event_id] = to_add
+
+ # Now we do two things:
+ # 1. Update the state sets to only include indexed events; and
+ # 2. Create a new list containing the auth chains of the un-indexed
+ # events
+ unindexed_state_sets: List[Set[str]] = []
+ for state_set in state_sets:
+ unindexed_state_set = set()
+ for event_id, auth_chain in event_id_to_partial_auth_chain.items():
+ if event_id not in state_set:
+ continue
+
+ unindexed_state_set.add(event_id)
+
+ state_set.discard(event_id)
+ state_set.difference_update(auth_chain)
+ for auth_id in auth_chain:
+ if auth_id in events_that_have_chain_index:
+ state_set.add(auth_id)
+ else:
+ unindexed_state_set.add(auth_id)
+
+ unindexed_state_sets.append(unindexed_state_set)
+
+ # Calculate and return the auth difference of the un-indexed events.
+ union = unindexed_state_sets[0].union(*unindexed_state_sets[1:])
+ intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:])
+
+ return union - intersection
+
def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
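The fixup path reduces the un-indexed part of the problem to plain set arithmetic: expand each state set's un-indexed events along their (partial) auth chains, then take the union minus the intersection across the expanded sets. A simplified, self-contained sketch of that step (it deliberately ignores the re-homing of indexed auth events back into `state_sets`, which the real helper also does):

```python
from typing import Dict, List, Set


def auth_difference_of_unindexed(
    state_sets: List[Set[str]],
    event_to_auth_ids: Dict[str, Set[str]],
) -> Set[str]:
    """Expand each state set along the known auth edges, then return
    union-minus-intersection, mirroring the final lines of
    _fixup_auth_chain_difference_sets (illustrative only)."""
    expanded: List[Set[str]] = []
    for state_set in state_sets:
        acc: Set[str] = set()
        for event_id in state_set:
            acc.add(event_id)
            processing = set(event_to_auth_ids.get(event_id, ()))
            while processing:
                auth_id = processing.pop()
                if auth_id in acc:
                    continue
                acc.add(auth_id)
                processing.update(event_to_auth_ids.get(auth_id, ()))
        expanded.append(acc)

    union = set().union(*expanded)
    intersection = expanded[0].intersection(*expanded[1:]) if expanded else set()
    return union - intersection


# Two branches that share auth event "A" but diverge at "B" and "C":
print(sorted(auth_difference_of_unindexed([{"B"}, {"C"}], {"B": {"A"}, "C": {"A"}})))
# -> ['B', 'C']  ("A" is reachable from both sets, so it is not in the difference)
```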
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
new file mode 100644
index 0000000000..1fb3180c3c
--- /dev/null
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -0,0 +1,202 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
+from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class TaskSchedulerWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ @staticmethod
+ def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
+ row["status"] = TaskStatus(row["status"])
+ if row["params"] is not None:
+ row["params"] = db_to_json(row["params"])
+ if row["result"] is not None:
+ row["result"] = db_to_json(row["result"])
+ return ScheduledTask(**row)
+
+ async def get_scheduled_tasks(
+ self,
+ *,
+ actions: Optional[List[str]] = None,
+ resource_id: Optional[str] = None,
+ statuses: Optional[List[TaskStatus]] = None,
+ max_timestamp: Optional[int] = None,
+ ) -> List[ScheduledTask]:
+ """Get a list of scheduled tasks from the DB.
+
+ Args:
+ actions: Limit the returned tasks to those specific action names
+ resource_id: Limit the returned tasks to the specific resource id, if specified
+ statuses: Limit the returned tasks to the specific statuses
+ max_timestamp: Limit the returned tasks to the ones that have
+                a timestamp equal to or lower than the specified one
+
+ Returns: a list of `ScheduledTask`, ordered by increasing timestamps
+ """
+
+ def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ clauses: List[str] = []
+ args: List[Any] = []
+ if resource_id:
+ clauses.append("resource_id = ?")
+ args.append(resource_id)
+ if actions is not None:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "action", actions
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+ if statuses is not None:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "status", statuses
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+ if max_timestamp is not None:
+ clauses.append("timestamp <= ?")
+ args.append(max_timestamp)
+
+ sql = "SELECT * FROM scheduled_tasks"
+ if clauses:
+ sql = sql + " WHERE " + " AND ".join(clauses)
+
+            sql = sql + " ORDER BY timestamp"
+
+ txn.execute(sql, args)
+ return self.db_pool.cursor_to_dict(txn)
+
+ rows = await self.db_pool.runInteraction(
+ "get_scheduled_tasks", get_scheduled_tasks_txn
+ )
+ return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows]
+
+ async def insert_scheduled_task(self, task: ScheduledTask) -> None:
+ """Insert a specified `ScheduledTask` in the DB.
+
+ Args:
+ task: the `ScheduledTask` to insert
+ """
+ await self.db_pool.simple_insert(
+ "scheduled_tasks",
+ {
+ "id": task.id,
+ "action": task.action,
+ "status": task.status,
+ "timestamp": task.timestamp,
+ "resource_id": task.resource_id,
+ "params": None
+ if task.params is None
+ else json_encoder.encode(task.params),
+ "result": None
+ if task.result is None
+ else json_encoder.encode(task.result),
+ "error": task.error,
+ },
+ desc="insert_scheduled_task",
+ )
+
+ async def update_scheduled_task(
+ self,
+ id: str,
+ timestamp: int,
+ *,
+ status: Optional[TaskStatus] = None,
+ result: Optional[JsonMapping] = None,
+ error: Optional[str] = None,
+ ) -> bool:
+ """Update a scheduled task in the DB with some new value(s).
+
+ Args:
+ id: id of the `ScheduledTask` to update
+ timestamp: new timestamp of the task
+ status: new status of the task
+ result: new result of the task
+ error: new error of the task
+
+ Returns: `False` if no matching row was found, `True` otherwise
+ """
+ updatevalues: JsonDict = {"timestamp": timestamp}
+ if status is not None:
+ updatevalues["status"] = status
+ if result is not None:
+ updatevalues["result"] = json_encoder.encode(result)
+ if error is not None:
+ updatevalues["error"] = error
+ nb_rows = await self.db_pool.simple_update(
+ "scheduled_tasks",
+ {"id": id},
+ updatevalues,
+ desc="update_scheduled_task",
+ )
+ return nb_rows > 0
+
+ async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]:
+ """Get a specific `ScheduledTask` from its id.
+
+ Args:
+ id: the id of the task to retrieve
+
+ Returns: the task if available, `None` otherwise
+ """
+ row = await self.db_pool.simple_select_one(
+ table="scheduled_tasks",
+ keyvalues={"id": id},
+ retcols=(
+ "id",
+ "action",
+ "status",
+ "timestamp",
+ "resource_id",
+ "params",
+ "result",
+ "error",
+ ),
+ allow_none=True,
+ desc="get_scheduled_task",
+ )
+
+ return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+
+ async def delete_scheduled_task(self, id: str) -> None:
+ """Delete a specific task from its id.
+
+ Args:
+ id: the id of the task to delete
+ """
+ await self.db_pool.simple_delete(
+ "scheduled_tasks",
+ keyvalues={"id": id},
+ desc="delete_scheduled_task",
+ )
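For reference, the WHERE/ORDER BY assembly in `get_scheduled_tasks` composes like this. The sketch below is standalone and inlines a naive `IN (...)` builder in place of Synapse's `make_in_list_sql_clause`:

```python
from typing import Any, List, Optional, Tuple


def build_scheduled_tasks_query(
    actions: Optional[List[str]] = None,
    resource_id: Optional[str] = None,
    statuses: Optional[List[str]] = None,
    max_timestamp: Optional[int] = None,
) -> Tuple[str, List[Any]]:
    """Illustrative sketch of how get_scheduled_tasks assembles its SQL."""
    clauses: List[str] = []
    args: List[Any] = []

    if resource_id:
        clauses.append("resource_id = ?")
        args.append(resource_id)
    if actions is not None:
        clauses.append("action IN (%s)" % ", ".join("?" * len(actions)))
        args.extend(actions)
    if statuses is not None:
        clauses.append("status IN (%s)" % ", ".join("?" * len(statuses)))
        args.extend(statuses)
    if max_timestamp is not None:
        clauses.append("timestamp <= ?")
        args.append(max_timestamp)

    sql = "SELECT * FROM scheduled_tasks"
    if clauses:
        sql += " WHERE " + " AND ".join(clauses)
    # Note the leading space: without it the query would read "...scheduled_tasksORDER BY...".
    sql += " ORDER BY timestamp"
    return sql, args


print(build_scheduled_tasks_query(statuses=["active"], max_timestamp=1000)[0])
# SELECT * FROM scheduled_tasks WHERE status IN (?) AND timestamp <= ? ORDER BY timestamp
```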
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index c3bd36efc9..48e4b0ba3c 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -242,6 +242,8 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> None:
# Upsert retry time interval if retry_interval is zero (i.e. we're
# resetting it) or greater than the existing retry interval.
+ # We also upsert when the new retry interval is the same as the existing one,
+        # since that will be the case once `destination_max_retry_interval` is reached.
#
# WARNING: This is executed in autocommit, so we shouldn't add any more
# SQL calls in here (without being very careful).
@@ -257,7 +259,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
WHERE
EXCLUDED.retry_interval = 0
OR destinations.retry_interval IS NULL
- OR destinations.retry_interval < EXCLUDED.retry_interval
+ OR destinations.retry_interval <= EXCLUDED.retry_interval
"""
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
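The switch from `<` to `<=` matters once the retry interval hits its configured cap: with a strict comparison, a retry computed at the cap equals the stored interval, the upsert's WHERE clause fails, and `retry_last_ts` stops advancing. A small sketch of the progression (multiplier and cap values are assumptions for the example, not Synapse's defaults):

```python
MULTIPLIER = 2
MAX_RETRY_INTERVAL_MS = 60 * 60 * 1000  # assumed destination_max_retry_interval


def next_retry_interval(current_ms: int) -> int:
    return min(current_ms * MULTIPLIER, MAX_RETRY_INTERVAL_MS)


def should_upsert(stored_ms: int, new_ms: int) -> bool:
    # Mirrors the WHERE clause above: a reset (0) always wins, otherwise the
    # row is updated when the stored interval is <= the new one.
    return new_ms == 0 or stored_ms <= new_ms


interval = 30 * 60 * 1000
for _ in range(3):
    new_interval = next_retry_interval(interval)
    print(new_interval, should_upsert(interval, new_interval))
    interval = new_interval
# Once the cap is reached, new_interval == interval; with "<" instead of "<="
# the update would be skipped and retry_last_ts would never move forward.
```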