diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index d64910aded..08c6eabc6d 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -60,18 +60,19 @@ class _BackgroundUpdateHandler:
class _BackgroundUpdateContextManager:
- BACKGROUND_UPDATE_INTERVAL_MS = 1000
- BACKGROUND_UPDATE_DURATION_MS = 100
-
- def __init__(self, sleep: bool, clock: Clock):
+ def __init__(
+ self, sleep: bool, clock: Clock, sleep_duration_ms: int, update_duration_ms: int
+ ):
self._sleep = sleep
self._clock = clock
+ self._sleep_duration_ms = sleep_duration_ms
+ self._update_duration_ms = update_duration_ms
async def __aenter__(self) -> int:
if self._sleep:
- await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+ await self._clock.sleep(self._sleep_duration_ms / 1000)
- return self.BACKGROUND_UPDATE_DURATION_MS
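+ # The value returned here is the target duration (in ms) that the next
+ # batch of work should aim for, used when autotuning the batch size.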
+ return self._update_duration_ms
async def __aexit__(self, *exc) -> None:
pass
@@ -102,10 +103,12 @@ class BackgroundUpdatePerformance:
Returns:
A duration in ms as a float
"""
- if self.avg_duration_ms == 0:
- return 0
- elif self.total_item_count == 0:
+ # We want to return None if this update has not processed any items yet
+ if self.total_item_count == 0:
return None
+ # Avoid dividing by zero
+ elif self.avg_duration_ms == 0:
+ return 0
else:
# Use the exponential moving average so that we can adapt to
# changes in how long the update process takes.
@@ -131,9 +134,6 @@ class BackgroundUpdater:
process and autotuning the batch size.
"""
- MINIMUM_BACKGROUND_BATCH_SIZE = 1
- DEFAULT_BACKGROUND_BATCH_SIZE = 100
-
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
@@ -158,6 +158,14 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API.
self.enabled = True
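+ # Pull the batch-size and pacing knobs from the `background_updates`
+ # config section, replacing the old hard-coded class constants.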
+ self.minimum_background_batch_size = hs.config.background_updates.min_batch_size
+ self.default_background_batch_size = (
+ hs.config.background_updates.default_batch_size
+ )
+ self.update_duration_ms = hs.config.background_updates.update_duration_ms
+ self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms
+ self.sleep_enabled = hs.config.background_updates.sleep_enabled
+
def register_update_controller_callbacks(
self,
on_update: ON_UPDATE_CALLBACK,
@@ -214,7 +222,9 @@ class BackgroundUpdater:
if self._on_update_callback is not None:
return self._on_update_callback(update_name, database_name, oneshot)
- return _BackgroundUpdateContextManager(sleep, self._clock)
+ return _BackgroundUpdateContextManager(
+ sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms
+ )
async def _default_batch_size(self, update_name: str, database_name: str) -> int:
"""The batch size to use for the first iteration of a new background
@@ -223,7 +233,7 @@ class BackgroundUpdater:
if self._default_batch_size_callback is not None:
return await self._default_batch_size_callback(update_name, database_name)
- return self.DEFAULT_BACKGROUND_BATCH_SIZE
+ return self.default_background_batch_size
async def _min_batch_size(self, update_name: str, database_name: str) -> int:
"""A lower bound on the batch size of a new background update.
@@ -233,7 +243,7 @@ class BackgroundUpdater:
if self._min_batch_size_callback is not None:
return await self._min_batch_size_callback(update_name, database_name)
- return self.MINIMUM_BACKGROUND_BATCH_SIZE
+ return self.minimum_background_batch_size
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any."""
@@ -252,9 +262,12 @@ class BackgroundUpdater:
if self.enabled:
# if we start a new background update, not all updates are done.
self._all_done = False
- run_as_background_process("background_updates", self.run_background_updates)
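+ # Whether we sleep between updates is now controlled by configuration
+ # rather than being hard-coded.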
+ sleep = self.sleep_enabled
+ run_as_background_process(
+ "background_updates", self.run_background_updates, sleep
+ )
- async def run_background_updates(self, sleep: bool = True) -> None:
+ async def run_background_updates(self, sleep: bool) -> None:
if self._running or not self.enabled:
return
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..367709a1a7 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
+from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -286,7 +288,7 @@ class LoggingTransaction:
"""
if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch # type: ignore
+ from psycopg2.extras import execute_batch
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
@@ -300,10 +302,18 @@ class LoggingTransaction:
rows (e.g. INSERTs).
"""
assert isinstance(self.database_engine, PostgresEngine)
- from psycopg2.extras import execute_values # type: ignore
+ from psycopg2.extras import execute_values
return self._do_execute(
- lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
+ # Type ignore: mypy is unhappy because if `x` is a 5-tuple, then there will
+ # be two values for `fetch`: one given positionally, and another given
+ # as a keyword argument. We might be able to fix this by
+ # - propagating the signature of psycopg2.extras.execute_values to this
+ # function, or
+ # - changing `*args: Any` to `values: T` for some appropriate T.
+ lambda *x: execute_values(self.txn, *x, fetch=fetch), # type: ignore[misc]
+ sql,
+ *args,
)
def execute(self, sql: str, *args: Any) -> None:
@@ -732,34 +742,45 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks: List[_CallbackListEntry] = []
- exception_callbacks: List[_CallbackListEntry] = []
- if not current_context():
- logger.warning("Starting db txn '%s' from sentinel context", desc)
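+ # Wrap the interaction in an inner coroutine so that it and its
+ # `after_callback`s/`exception_callback`s can be protected from
+ # cancellation as a single unit (see `delay_cancellation` below).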
+ async def _runInteraction() -> R:
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
- try:
- with opentracing.start_active_span(f"db.{desc}"):
- result = await self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- db_autocommit=db_autocommit,
- isolation_level=isolation_level,
- **kwargs,
- )
+ if not current_context():
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except Exception:
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
+ try:
+ with opentracing.start_active_span(f"db.{desc}"):
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ db_autocommit=db_autocommit,
+ isolation_level=isolation_level,
+ **kwargs,
+ )
- return cast(R, result)
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
+ return cast(R, result)
+ except Exception:
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ # To handle cancellation, we ensure that `after_callback`s and
+ # `exception_callback`s are always run, since the transaction will complete
+ # on another thread regardless of cancellation.
+ #
+ # We also wait until everything above is done before releasing the
+ # `CancelledError`, so that logging contexts won't get used after they have been
+ # finished.
+ return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection(
self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 52146aacc8..9af9f4f18e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
@cached(max_entries=5000, iterable=True)
- async def ignored_by(self, user_id: str) -> Set[str]:
+ async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Get users which ignore the given user.
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Return:
The user IDs which ignore the given user.
"""
- return set(
+ return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
)
+ @cached(max_entries=5000, iterable=True)
+ async def ignored_users(self, user_id: str) -> FrozenSet[str]:
+ """
+ Get users which the given user ignores.
+
+ Args:
+ user_id: The user ID which is making the request.
+
+ Returns:
+ The user IDs which are ignored by the given user.
+ """
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ retcol="ignored_user_id",
+ desc="ignored_users",
+ )
+ )
+
def process_replication_rows(
self,
stream_name: str,
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
currently_ignored_users = set()
+ # If the data has not changed, nothing to do.
+ if previously_ignored_users == currently_ignored_users:
+ return
+
# Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn(
txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index c428dd5596..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
+ EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_caches_txn(txn):
+ def get_all_updated_caches_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
- def _process_event_stream_row(self, token, row):
+ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
if row.type == EventsStreamEventRow.TypeId:
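+ # `data` is a union of the possible stream row data types; `row.type`
+ # tells us which one this is, so assert it to narrow the type for mypy.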
+ assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
+ assert isinstance(data, EventsStreamCurrentStateRow)
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event(
self,
- stream_ordering,
- event_id,
- room_id,
- etype,
- state_key,
- redacts,
- relates_to,
- backfilled,
- ):
+ stream_ordering: int,
+ event_id: str,
+ room_id: str,
+ etype: str,
+ state_key: Optional[str],
+ redacts: Optional[str],
+ relates_to: Optional[str],
+ backfilled: bool,
+ ) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
@@ -186,11 +192,19 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
+ # The `_get_membership_from_event_id` cache is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ self._get_membership_from_event_id.invalidate((event_id,))
+
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
self._invalidate_get_event_cache(redacts)
+ # Caches which might leak edits must be invalidated for the event being
+ # redacted.
+ self.get_relations_for_event.invalidate((redacts,))
+ self.get_applicable_edit.invalidate((redacts,))
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
@@ -200,8 +214,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_relations_for_event.invalidate((relates_to,))
self.get_aggregation_groups_for_event.invalidate((relates_to,))
self.get_applicable_edit.invalidate((relates_to,))
+ self.get_thread_summary.invalidate((relates_to,))
+ self.get_thread_participated.invalidate((relates_to,))
- async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ async def invalidate_cache_and_stream(
+ self, cache_name: str, keys: Tuple[Any, ...]
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -221,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys,
)
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ def _invalidate_cache_and_stream(
+ self,
+ txn: LoggingTransaction,
+ cache_func: _CachedFunction,
+ keys: Tuple[Any, ...],
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -232,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
- def _invalidate_all_cache_and_stream(self, txn, cache_func):
+ def _invalidate_all_cache_and_stream(
+ self, txn: LoggingTransaction, cache_func: _CachedFunction
+ ) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@@ -273,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
- self, txn, cache_name: str, keys: Optional[Iterable[Any]]
- ):
+ self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+ ) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
@@ -309,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
- "invalidation_ts": self.clock.time_msec(),
+ "invalidation_ts": self._clock.time_msec(),
},
)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1392363de1..b4a1b041b1 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -298,6 +298,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# This user has new messages sent to them. Query messages for them
user_ids_to_query.add(user_id)
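+ # Short-circuit when no users have new messages, so that we avoid an
+ # unnecessary database query.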
+ if not user_ids_to_query:
+ return {}, to_stream_id
+
def get_device_messages_txn(txn: LoggingTransaction):
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ca2a9ba9d1..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1518,7 +1518,7 @@ class PersistEventsStore:
)
# Remove from relations table.
- self._handle_redaction(txn, event.redacts)
+ self._handle_redact_relations(txn, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
@@ -1619,9 +1619,12 @@ class PersistEventsStore:
txn.call_after(prefill)
- def _store_redaction(self, txn, event):
- # invalidate the cache for the redacted event
+ def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
+ # Invalidate the caches for the redacted event; note that these caches
+ # are also cleared as part of event replication in _invalidate_caches_for_event.
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+ txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
+ txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
self.db_pool.simple_upsert_txn(
txn,
@@ -1742,6 +1745,13 @@ class PersistEventsStore:
(event.state_key,),
)
+ # The `_get_membership_from_event_id` cache is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ txn.call_after(
+ self.store._get_membership_from_event_id.invalidate,
+ (event.event_id,),
+ )
+
# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#
@@ -1811,10 +1821,11 @@ class PersistEventsStore:
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
- if rel_type == RelationTypes.THREAD:
- txn.call_after(
- self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
- )
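+ # Clients may still send the unstable (MSC3440) relation type for
+ # threads, so treat it the same as the stable one here.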
+ if (
+ rel_type == RelationTypes.THREAD
+ or rel_type == RelationTypes.UNSTABLE_THREAD
+ ):
+ txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
@@ -1943,15 +1954,43 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,))
- def _handle_redaction(self, txn, redacted_event_id):
- """Handles receiving a redaction and checking whether we need to remove
- any redacted relations from the database.
+ def _handle_redact_relations(
+ self, txn: LoggingTransaction, redacted_event_id: str
+ ) -> None:
+ """Handles receiving a redaction and checking whether the redacted event
+ has any relations which must be removed from the database.
Args:
txn
- redacted_event_id (str): The event that was redacted.
+ redacted_event_id: The ID of the event that was redacted.
"""
+ # Fetch the current relation of the event being redacted.
+ redacted_relates_to = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_relations",
+ keyvalues={"event_id": redacted_event_id},
+ retcol="relates_to_id",
+ allow_none=True,
+ )
+ # Any relation information for the related event must be cleared.
+ if redacted_relates_to is not None:
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_relations_for_event, (redacted_relates_to,)
+ )
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
+ )
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_applicable_edit, (redacted_relates_to,)
+ )
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_thread_summary, (redacted_relates_to,)
+ )
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_thread_participated, (redacted_relates_to,)
+ )
+
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 26784f755e..59454a47df 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1286,7 +1286,7 @@ class EventsWorkerStore(SQLBaseStore):
)
return {eid for ((_rid, eid), have_event) in res.items() if have_event}
- @cachedList("have_seen_event", "keys")
+ @cachedList(cached_method_name="have_seen_event", list_name="keys")
async def _have_seen_events_dict(
self, keys: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
@@ -1954,7 +1954,7 @@ class EventsWorkerStore(SQLBaseStore):
get_event_id_for_timestamp_txn,
)
- @cachedList("is_partial_state_event", list_name="event_ids")
+ @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 3f6086050b..0aef121d83 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]:
# TODO: Pagination
- keyvalues = {"group_id": group_id}
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination
- def _get_rooms_in_group_txn(txn):
+ def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
@@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category
"""
- def _get_rooms_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_rooms_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
- async def get_group_categories(self, group_id):
+ async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
@@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_category(self, group_id, category_id):
+ async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category
- async def get_group_roles(self, group_id):
+ async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
@@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_role(self, group_id, role_id):
+ async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
- async def get_users_for_summary_by_role(self, group_id, include_private=False):
+ async def get_users_for_summary_by_role(
+ self, group_id: str, include_private: bool = False
+ ) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request
Returns:
([users], [roles])
"""
- def _get_users_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_users_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], JsonDict]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
- async def get_users_membership_info_in_group(self, group_id, user_id):
+ async def get_users_membership_info_in_group(
+ self, group_id: str, user_id: str
+ ) -> JsonDict:
"""Get a dict describing the membership of a user in a group.
Example if joined:
@@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc
"""
- def _get_users_membership_in_group_txn(txn):
+ def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
@@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
- async def get_attestations_need_renewals(self, valid_until_ms):
+ async def get_attestations_need_renewals(
+ self, valid_until_ms: int
+ ) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time"""
- def _get_attestations_need_renewals_txn(txn):
+ def _get_attestations_need_renewals_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
sql = """
SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ?
@@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
- async def get_remote_attestation(self, group_id, user_id):
+ async def get_remote_attestation(
+ self, group_id: str, user_id: str
+ ) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
@@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
- async def get_all_groups_for_user(self, user_id, now_token):
- def _get_all_groups_for_user_txn(txn):
+ async def get_all_groups_for_user(
+ self, user_id: str, now_token: int
+ ) -> List[JsonDict]:
+ def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
@@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
- async def get_groups_changes_for_user(self, user_id, from_token, to_token):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_entity_changed(
+ async def get_groups_changes_for_user(
+ self, user_id: str, from_token: int, to_token: int
+ ) -> List[JsonDict]:
+ has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token
)
if not has_changed:
return []
- def _get_groups_changes_for_user_txn(txn):
+ def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
@@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
last_id = int(last_id)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed:
return [], current_id, False
- def _get_all_groups_changes_txn(txn):
+ def _get_all_groups_changes_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
@@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [
- (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ]
+ updates = cast(
+ List[Tuple[int, tuple]],
+ [
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
+ for stream_id, group_id, user_id, gtype, content_json in txn
+ ],
+ )
limited = False
upto_token = current_id
@@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id,
"room_id": room_id,
},
- values=to_update,
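+ # NB: `simple_update_txn` takes `updatevalues`, not `values`; the old
+ # keyword argument would have raised a TypeError when this path ran.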
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_summary(
- self, group_id: str, room_id: str, category_id: str
+ self, group_id: str, room_id: str, category_id: Optional[str]
) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
@@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/update room category for group"""
- insertion_values = {}
- update_values = {"category_id": category_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/remove user role"""
- insertion_values = {}
- update_values = {"role_id": role_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
@@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
- ):
+ ) -> None:
"""Add (or update) user's entry in summary.
Args:
@@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id,
"user_id": user_id,
},
- values=to_update,
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_user_from_summary(
- self, group_id: str, user_id: str, role_id: str
+ self, group_id: str, user_id: str, role_id: Optional[str]
) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
@@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server
"""
- def _add_user_to_group_txn(txn):
+ def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
table="group_users",
@@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
- def _remove_user_from_group_txn(txn):
+ def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_users",
@@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
- def _remove_room_from_group_txn(txn):
+ def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
@@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {}
- def _register_user_group_membership_txn(txn, next_id):
+ def _register_user_group_membership_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> int:
# TODO: Upsert?
self.db_pool.simple_delete_txn(
txn,
@@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
),
},
)
- self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+ self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join.
@@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- async with self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res
async def create_group(
- self, group_id, user_id, name, avatar_url, short_description, long_description
+ self,
+ group_id: str,
+ user_id: str,
+ name: str,
+ avatar_url: str,
+ short_description: str,
+ long_description: str,
) -> None:
await self.db_pool.simple_insert(
table="groups",
@@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
- async def update_group_profile(self, group_id, profile):
+ async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
@@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal",
)
- def get_group_stream_token(self):
- return self._group_updates_id_gen.get_current_token()
+ def get_group_stream_token(self) -> int:
+ return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
@@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete.
"""
- def _delete_group_txn(txn):
+ def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [
"groups",
"group_users",
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e9a0cdc6be..216622964a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
+ LoggingTransaction,
make_in_list_sql_clause,
)
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email
@@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users
"""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users
sql = """
SELECT COUNT(*)
@@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
"""
txn.execute(sql)
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
- def _count_users_by_service(txn):
+ def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
@@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
txn.execute(sql)
- result = txn.fetchall()
+ result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction(
@@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
@wrap_as_background_process("reap_monthly_active_users")
- async def reap_monthly_active_users(self):
+ async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale
entries exist.
"""
- def _reap_users(txn, reserved_users):
+ def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
"""
Args:
- reserved_users (tuple): reserved users to preserve
+ reserved_users: reserved users to preserve
@@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
- self._invalidate_all_cache_and_stream(
+ self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active
)
- self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
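+ # Also inherit from RegistrationWorkerStore, presumably so that helpers
+ # such as `is_trial_user` resolve for the type checker without ignores.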
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
- def _initialise_reserved_users(self, txn, threepids):
+ def _initialise_reserved_users(
+ self, txn: LoggingTransaction, threepids: List[dict]
+ ) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
+ txn:
+ threepids: List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
@@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
- def upsert_monthly_active_user_txn(self, txn, user_id):
+ def upsert_monthly_active_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> None:
"""Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it
@@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,)
)
- async def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = await self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest:
return
is_trial = await self.is_trial_user(user_id)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index dc6665237a..a698d10cc5 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -48,8 +48,6 @@ class ExternalIDReuseException(Exception):
"""Exception if writing an external id for a user fails,
because this external id is given to an other user."""
- pass
-
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 36aa1092f6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
)
import attr
-from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
- # The latest event in the thread.
- latest_event: EventBase
- # The latest edit to the latest event in the thread.
- latest_edit: Optional[EventBase]
- # The total number of events in the thread.
- count: int
- # True if the current user has sent an event to the thread.
- current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
- """
- The bundled aggregations for an event.
-
- Some values require additional processing during serialization.
- """
-
- annotations: Optional[JsonDict] = None
- references: Optional[JsonDict] = None
- replace: Optional[EventBase] = None
- thread: Optional[_ThreadAggregation] = None
-
- def __bool__(self) -> bool:
- return bool(self.annotations or self.references or self.replace or self.thread)
-
-
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -91,10 +60,11 @@ class RelationsWorkerStore(SQLBaseStore):
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
- @cached(tree=True)
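+ # `event` is only needed for its redaction status, so it is excluded
+ # from the cache key via `uncached_args`.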
+ @cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
event_id: str,
+ event: EventBase,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
@@ -108,6 +78,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ event: The event corresponding to `event_id`.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
@@ -122,9 +93,13 @@ class RelationsWorkerStore(SQLBaseStore):
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
"""
+ # We don't use `event_id`, it's there so that we can cache based on
+ # it. The `event_id` must match `event.event_id`.
+ assert event.event_id == event_id
where_clause = ["relates_to_id = ?", "room_id = ?"]
- where_args: List[Union[str, int]] = [event_id, room_id]
+ where_args: List[Union[str, int]] = [event.event_id, room_id]
+ is_redacted = event.internal_metadata.is_redacted()
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -157,7 +132,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -178,9 +153,12 @@ class RelationsWorkerStore(SQLBaseStore):
last_stream_id = None
events = []
for row in txn:
- events.append({"event_id": row[0]})
- last_topo_id = row[1]
- last_stream_id = row[2]
+ # Do not include edits for redacted events as they leak event
+ # content.
+ if not is_redacted or row[1] != RelationTypes.REPLACE:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[2]
+ last_stream_id = row[3]
# If there are more events, generate the next pagination key.
next_token = None
@@ -375,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
- async def _get_applicable_edits(
+ async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
@@ -464,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
- async def _get_thread_summaries(
+ async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -499,7 +477,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
"""
else:
@@ -514,16 +492,22 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
- args.append(RelationTypes.THREAD)
- txn.execute(sql % (clause,), args)
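+ # Match the stable thread relation type, and additionally the unstable
+ # one while MSC3440 support is enabled.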
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+ else:
+ relations_clause = "relation_type = ?"
+ args.append(RelationTypes.THREAD)
+
+ txn.execute(sql % (clause, relations_clause), args)
latest_event_ids = {}
for parent_event_id, child_event_id in txn:
# Only consider the latest threaded reply (by topological ordering).
@@ -543,7 +527,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
GROUP BY parent.event_id
"""
@@ -552,9 +536,15 @@ class RelationsWorkerStore(SQLBaseStore):
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", latest_event_ids.keys()
)
- args.append(RelationTypes.THREAD)
- txn.execute(sql % (clause,), args)
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+ else:
+ relations_clause = "relation_type = ?"
+ args.append(RelationTypes.THREAD)
+
+ txn.execute(sql % (clause, relations_clause), args)
counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
return counts, latest_event_ids
@@ -566,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
- latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+ latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@@ -589,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
- async def _get_threads_participated(
+ async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
@@ -617,16 +607,24 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
AND child.sender = ?
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
- args.extend((RelationTypes.THREAD, user_id))
- txn.execute(sql % (clause,), args)
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+ else:
+ relations_clause = "relation_type = ?"
+ args.append(RelationTypes.THREAD)
+
+ args.append(user_id)
+
+ txn.execute(sql % (clause, relations_clause), args)
return {row[0] for row in txn.fetchall()}
participated_threads = await self.db_pool.runInteraction(
@@ -737,122 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- async def _get_bundled_aggregation_for_event(
- self, event: EventBase, user_id: str
- ) -> Optional[BundledAggregations]:
- """Generate bundled aggregations for an event.
-
- Note that this does not use a cache, but depends on cached methods.
-
- Args:
- event: The event to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- The bundled aggregations for an event, if bundled aggregations are
- enabled and the event can have bundled aggregations.
- """
-
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return None
-
- event_id = event.event_id
- room_id = event.room_id
-
- # The bundled aggregations to include, a mapping of relation type to a
- # type-specific value. Some types include the direct return type here
- # while others need more processing during serialization.
- aggregations = BundledAggregations()
-
- annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
- if annotations.chunk:
- aggregations.annotations = await annotations.to_dict(
- cast("DataStore", self)
- )
-
- references = await self.get_relations_for_event(
- event_id, room_id, RelationTypes.REFERENCE, direction="f"
- )
- if references.chunk:
- aggregations.references = await references.to_dict(cast("DataStore", self))
-
- # Store the bundled aggregations in the event metadata for later use.
- return aggregations
-
- async def get_bundled_aggregations(
- self, events: Iterable[EventBase], user_id: str
- ) -> Dict[str, BundledAggregations]:
- """Generate bundled aggregations for events.
-
- Args:
- events: The iterable of events to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- A map of event ID to the bundled aggregation for the event. Not all
- events may have bundled aggregations in the results.
- """
- # The already processed event IDs. Tracked separately from the result
- # since the result omits events which do not have bundled aggregations.
- seen_event_ids = set()
-
- # State events and redacted events do not get bundled aggregations.
- events = [
- event
- for event in events
- if not event.is_state() and not event.internal_metadata.is_redacted()
- ]
-
- # event ID -> bundled aggregation in non-serialized form.
- results: Dict[str, BundledAggregations] = {}
-
- # Fetch other relations per event.
- for event in events:
- # De-duplicate events by ID to handle the same event requested multiple
- # times. The caches that _get_bundled_aggregation_for_event use should
- # capture this, but best to reduce work.
- if event.event_id in seen_event_ids:
- continue
- seen_event_ids.add(event.event_id)
-
- event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result:
- results[event.event_id] = event_result
-
- # Fetch any edits.
- edits = await self._get_applicable_edits(seen_event_ids)
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
-
- # Fetch thread summaries.
- if self._msc3440_enabled:
- summaries = await self._get_thread_summaries(seen_event_ids)
- # Only fetch participated for a limited selection based on what had
- # summaries.
- participated = await self._get_threads_participated(
- summaries.keys(), user_id
- )
- for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event, edit = summary
- results.setdefault(
- event_id, BundledAggregations()
- ).thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- latest_edit=edit,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=participated[event_id],
- )
-
- return results
-
class RelationsStore(RelationsWorkerStore):
pass
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e48ec5f495..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -46,7 +46,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id
+from synapse.types import PersistedEventPosition, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+ """Returned by `get_membership_from_event_ids`"""
+
+ user_id: str
+ membership: str
+
+
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
@@ -273,7 +281,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (room_id,))
res = {}
for count, membership in txn:
- summary = res.setdefault(membership, MemberSummary([], count))
+ res.setdefault(membership, MemberSummary([], count))
# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
- desc="_get_membership_from_event_ids",
+ desc="_get_joined_profiles_from_event_ids",
)
return {
@@ -839,18 +847,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry
+ room_id, state_group, state_entry=state_entry
)
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
- self,
- room_id: str,
- state_group: int,
- current_state_ids: StateMap[str],
- state_entry: "_StateCacheEntry",
+ self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
- # We don't use `state_group`, its there so that we can cache based on
- # it. However, its important that its never None, since two
+ # We don't use `state_group`, it's there so that we can cache based on
+ # it. However, it's important that it's never None, since two
# current_state's with a state_group of None are likely to be different.
#
@@ -1004,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
+ @cached(max_entries=5000)
+ async def _get_membership_from_event_id(
+ self, member_event_id: str
+ ) -> Optional[EventIdMembership]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+ )
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
- ) -> List[dict]:
- """Get user_id and membership of a set of event IDs."""
+ ) -> Dict[str, Optional[EventIdMembership]]:
+ """Get user_id and membership of a set of event IDs.
- return await self.db_pool.simple_select_many_batch(
+ Returns:
+ Mapping from event ID to `EventIdMembership` if the event is a
+ membership event, otherwise the value is None.
+ """
+
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -1019,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids",
)
+ return {
+ row["event_id"]: EventIdMembership(
+ membership=row["membership"], user_id=row["user_id"]
+ )
+ for row in rows
+ }
+
async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
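
The `@cached` / `@cachedList` pairing added above is a Synapse idiom: the per-item method exists only as a cache slot (hence `raise NotImplementedError()`), while the batch method serves misses and back-fills the per-item cache. A rough sketch of the underlying idea, with plain dicts standing in for the real descriptors and database; none of the invalidation or async machinery is modelled.

```python
from typing import Dict, Iterable, Optional


class EventIdMembershipStore:
    def __init__(self) -> None:
        # Per-event-ID cache, shared by the single and batch lookups.
        self._cache: Dict[str, Optional[str]] = {}

    def get_membership_from_event_id(self, event_id: str) -> Optional[str]:
        # Single-item entry point: in the real decorators this is served
        # purely from the cache, which is why the decorated method in the
        # diff only raises NotImplementedError.
        if event_id in self._cache:
            return self._cache[event_id]
        result = self._fetch_batch([event_id])[event_id]
        self._cache[event_id] = result
        return result

    def get_membership_from_event_ids(
        self, event_ids: Iterable[str]
    ) -> Dict[str, Optional[str]]:
        # Batch entry point: fetch only the misses, then back-fill the cache.
        wanted = list(event_ids)
        missing = [e for e in wanted if e not in self._cache]
        if missing:
            self._cache.update(self._fetch_batch(missing))
        return {e: self._cache[e] for e in wanted}

    def _fetch_batch(self, event_ids: Iterable[str]) -> Dict[str, Optional[str]]:
        # Stand-in for the `simple_select_many_batch` query; returns None
        # for IDs that are not membership events, matching the new contract.
        fake_db = {"$join": "join", "$leave": "leave"}
        return {e: fake_db.get(e) for e in event_ids}


store = EventIdMembershipStore()
assert store.get_membership_from_event_ids(["$join", "$topic"]) == {
    "$join": "join",
    "$topic": None,
}
```
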
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e23b119072..c5e9010c83 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -125,9 +125,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
):
super().__init__(database, db_conn, hs)
- if not hs.config.server.enable_search:
- return
-
self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
@@ -243,9 +240,13 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return len(event_search_rows)
- result = await self.db_pool.runInteraction(
- self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
- )
+ if self.hs.config.server.enable_search:
+ result = await self.db_pool.runInteraction(
+ self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
+ )
+ else:
+ # Don't index anything if search is not enabled.
+ result = 0
if not result:
await self.db_pool.updates._end_background_update(
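
Returning `0` when search is disabled works because of the background-update contract: a handler reports how many items it processed, and a falsy result causes `_end_background_update` to mark the update complete. A toy model of that loop; the names here are illustrative, not Synapse's API.

```python
from typing import Callable, List

Handler = Callable[[int], int]


def run_update(handler: Handler, batch_size: int) -> List[int]:
    """Drive a handler until it reports no work, collecting batch results."""
    results: List[int] = []
    while True:
        processed = handler(batch_size)
        results.append(processed)
        if not processed:
            # Mirrors `_end_background_update`: a falsy return ends the
            # update, which is why returning 0 when search is disabled
            # finishes the reindex immediately instead of never registering
            # the handler at all.
            return results


remaining = [250]


def reindex(batch_size: int) -> int:
    done = min(batch_size, remaining[0])
    remaining[0] -= done
    return done


assert run_update(reindex, 100) == [100, 100, 50, 0]
```
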
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index a898f847e7..39e1efe373 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
args.extend(event_filter.labels)
# Filter on relation_senders / relation types from the joined tables.
- if event_filter.relation_senders:
+ if event_filter.related_by_senders:
clauses.append(
"(%s)"
% " OR ".join(
- "related_event.sender = ?" for _ in event_filter.relation_senders
+ "related_event.sender = ?" for _ in event_filter.related_by_senders
)
)
- args.extend(event_filter.relation_senders)
+ args.extend(event_filter.related_by_senders)
- if event_filter.relation_types:
+ if event_filter.related_by_rel_types:
clauses.append(
"(%s)"
- % " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
+ % " OR ".join(
+ "relation_type = ?" for _ in event_filter.related_by_rel_types
+ )
)
- args.extend(event_filter.relation_types)
+ args.extend(event_filter.related_by_rel_types)
return " AND ".join(clauses), args
@@ -1203,7 +1205,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- # If there is a filter on relation_senders and relation_types join to the
- # relations table.
+ # If there is a filter on related_by_senders and related_by_rel_types join
+ # to the relations table.
if event_filter and (
- event_filter.relation_senders or event_filter.relation_types
+ event_filter.related_by_senders or event_filter.related_by_rel_types
):
# Filtering by relations could cause the same event to appear multiple
# times (since there's no limit on the number of relations to an event).
@@ -1211,7 +1213,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
join_clause += """
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
"""
- if event_filter.relation_senders:
+ if event_filter.related_by_senders:
join_clause += """
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
"""
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..0595df01d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
cast,
)
+from typing_extensions import TypedDict
+
from synapse.api.errors import StoreError
if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+ JsonDict,
+ UserProfile,
+ get_domain_from_id,
+ get_localpart_from_id,
+)
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
+class SearchResult(TypedDict):
+ limited: bool
+ results: List[UserProfile]
+
+
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- async def get_shared_rooms_for_users(
+ async def get_mutual_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
"""
@@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share.
"""
- def _get_shared_rooms_for_users_txn(
+ def _get_mutual_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
@@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return rows
rows = await self.db_pool.runInteraction(
- "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+ "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
)
return {row["room_id"] for row in rows}
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ results = cast(
+ List[UserProfile],
+ await self.db_pool.execute(
+ "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ ),
)
limited = len(results) > limit
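
`SearchResult` replaces the loose `JsonDict` with a precise shape, and the `cast` is needed because `cursor_to_dict` returns untyped rows. A compressed sketch of the pattern, assuming `UserProfile` is itself a `TypedDict`-style row (the definitions below are illustrative, not Synapse's):

```python
from typing import Any, List, cast

from typing_extensions import TypedDict


class UserProfile(TypedDict):
    user_id: str
    display_name: str


class SearchResult(TypedDict):
    limited: bool
    results: List[UserProfile]


def search(rows: List[Any], limit: int) -> SearchResult:
    # `cast` asserts the row shape for the type checker; it performs no
    # runtime validation.
    results = cast(List[UserProfile], rows)
    # Fetching one row beyond `limit` is the usual way to detect truncation.
    return {"limited": len(results) > limit, "results": results[:limit]}


print(search([{"user_id": "@a:x", "display_name": "A"}], limit=10))
```
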
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9abc02046e..afb7d5054d 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
if name == "psycopg2":
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
- import psycopg2 # type: ignore
+ import psycopg2
return PostgresEngine(psycopg2, database_config)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 808342fafb..e8d29e2870 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+ self.config = database_config
@property
def single_threaded(self) -> bool:
return False
+ def get_db_locale(self, txn):
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+ return collation, ctype
+
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+ allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 100000:
@@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
"See docs/postgres.md for more information." % (rows[0][0],)
)
- txn.execute(
- "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
- )
- collation, ctype = txn.fetchone()
+ collation, ctype = self.get_db_locale(txn)
if collation != "C":
logger.warning(
- "Database has incorrect collation of %r. Should be 'C'\n"
- "See docs/postgres.md for more information.",
+ "Database has incorrect collation of %r. Should be 'C'",
collation,
)
+ if not allow_unsafe_locale:
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect collation of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information. You can override this check by"
+ "setting 'allow_unsafe_locale' to true in the database config.",
+ collation,
+ )
if ctype != "C":
- logger.warning(
- "Database has incorrect ctype of %r. Should be 'C'\n"
- "See docs/postgres.md for more information.",
- ctype,
- )
+ logger.warning(
+ "Database has incorrect ctype of %r. Should be 'C'",
+ ctype,
+ )
+ if not allow_unsafe_locale:
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect ctype of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information. You can override this check by "
+ "setting 'allow_unsafe_locale' to true in the database config."
+ % (ctype,)
+ )
def check_new_database(self, txn):
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
- txn.execute(
- "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
- )
- collation, ctype = txn.fetchone()
+ collation, ctype = self.get_db_locale(txn)
errors = []
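
Pulled together, the new behaviour is: query the locale once via `get_db_locale`, warn on any non-`C` value, and only hard-fail when `allow_unsafe_locale` is unset. A condensed sketch of that logic, with a plain-dict config and illustrative names:

```python
import logging

logger = logging.getLogger(__name__)


class IncorrectDatabaseSetup(Exception):
    pass


def check_locale(collation: str, ctype: str, config: dict) -> None:
    """Warn on a non-'C' locale; raise unless the override flag is set."""
    allow_unsafe = config.get("allow_unsafe_locale", False)
    for name, value in (("collation", collation), ("ctype", ctype)):
        if value != "C":
            logger.warning("Database has incorrect %s of %r. Should be 'C'", name, value)
            if not allow_unsafe:
                raise IncorrectDatabaseSetup(
                    "Database has incorrect %s of %r. Should be 'C'. You can "
                    "override this check by setting 'allow_unsafe_locale' to "
                    "true in the database config." % (name, value)
                )


check_locale("C", "C", {})  # passes silently
check_locale("en_US.UTF-8", "C", {"allow_unsafe_locale": True})  # warns only
```
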
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 7d543fdbe0..b402922817 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
# Check if any of the changes that we don't have events for are joins.
if events_to_check:
- rows = await self.main_store.get_membership_from_event_ids(events_to_check)
- is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+ members = await self.main_store.get_membership_from_event_ids(
+ events_to_check
+ )
+ is_still_joined = any(
+ member and member.membership == Membership.JOIN
+ for member in members.values()
+ )
if is_still_joined:
return True
@@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
- rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+ members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
- row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+ member.user_id
+ for member in members.values()
+ if member and member.membership == Membership.JOIN
)
return False
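
Both call sites now receive a mapping whose values may be `None` (for IDs that are not membership events), so they guard with `member and ...` before touching attributes. In miniature, assuming a `NamedTuple` stand-in for `EventIdMembership`:

```python
from typing import Dict, NamedTuple, Optional


class EventIdMembership(NamedTuple):
    user_id: str
    membership: str


members: Dict[str, Optional[EventIdMembership]] = {
    "$join": EventIdMembership("@a:x", "join"),
    "$topic": None,  # requested ID that was not a membership event
}

# The `member and` guard skips None values before reading `.membership`.
is_still_joined = any(
    member and member.membership == "join" for member in members.values()
)
assert is_still_joined
```
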
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 36ca2b8273..fba270150b 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -55,37 +55,6 @@ class PaginationChunk:
@attr.s(frozen=True, slots=True, auto_attribs=True)
-class RelationPaginationToken:
- """Pagination token for relation pagination API.
-
- As the results are in topological order, we can use the
- `topological_ordering` and `stream_ordering` fields of the events at the
- boundaries of the chunk as pagination tokens.
-
- Attributes:
- topological: The topological ordering of the boundary event
- stream: The stream ordering of the boundary event.
- """
-
- topological: int
- stream: int
-
- @staticmethod
- def from_string(string: str) -> "RelationPaginationToken":
- try:
- t, s = string.split("-")
- return RelationPaginationToken(int(t), int(s))
- except ValueError:
- raise SynapseError(400, "Invalid relation pagination token")
-
- async def to_string(self, store: "DataStore") -> str:
- return "%d-%d" % (self.topological, self.stream)
-
- def as_tuple(self) -> Tuple[Any, ...]:
- return attr.astuple(self)
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
class AggregationPaginationToken:
"""Pagination token for relation aggregation pagination API.
diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py
index 22a7901e15..4b4b166e37 100644
--- a/synapse/storage/schema/main/delta/30/as_users.py
+++ b/synapse/storage/schema/main/delta/30/as_users.py
@@ -36,7 +36,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
config_files = config.appservice.app_service_config_files
except AttributeError:
logger.warning("Could not get app_service_config_files from config")
- pass
appservices = load_appservices(config.server.server_name, config_files)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e79ecf64a0..86f1a5373b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -561,7 +561,7 @@ class StateGroupStorage:
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
- self, _room_id: str, event_ids: Iterable[str]
+ self, _room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
@@ -596,7 +596,7 @@ class StateGroupStorage:
return group_to_state[state_group]
async def get_state_groups(
- self, room_id: str, event_ids: Iterable[str]
+ self, room_id: str, event_ids: Collection[str]
) -> Dict[int, List[EventBase]]:
"""Get the state groups for the given list of event_ids
@@ -648,7 +648,7 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
- self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
+ self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -684,7 +684,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
- self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
+ self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
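
Tightening `Iterable` to `Collection` in these signatures documents that the implementations may take `len()` and iterate the input more than once, which a one-shot generator cannot support. A small illustration:

```python
from typing import Collection, Dict


def state_for_events(event_ids: Collection[str]) -> Dict[str, int]:
    groups = {e: hash(e) for e in event_ids}   # first pass over the input
    return {e: groups[e] for e in event_ids}   # second pass over the input


gen = (e for e in ["$a", "$b"])
# state_for_events(gen) would type-check under Iterable but return {} from
# the second pass once the generator is exhausted; Collection rules it out
# at type-check time and also guarantees len() is available.
print(state_for_events(["$a", "$b"]))
```
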