diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 11d9d16c19..4dccbb732a 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -24,9 +24,9 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -45,7 +45,6 @@ from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
-from .group_server import GroupServerStore
from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
@@ -88,7 +87,6 @@ class DataStore(
RoomStore,
RoomBatchStore,
RegistrationStore,
- StreamWorkerStore,
ProfileStore,
PresenceStore,
TransactionWorkerStore,
@@ -105,19 +103,20 @@ class DataStore(
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
+ EventPushActionsStore,
+ ServerMetricsStore,
ReceiptsStore,
EndToEndKeyStore,
EndToEndRoomKeyStore,
SearchStore,
TagsStore,
AccountDataStore,
- EventPushActionsStore,
+ StreamWorkerStore,
OpenIdStore,
ClientIpWorkerStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
- GroupServerStore,
UserErasureStore,
MonthlyActiveUsersWorkerStore,
StatsStore,
@@ -126,7 +125,6 @@ class DataStore(
UIAuthStore,
EventForwardExtremitiesStore,
CacheInvalidationWorkerStore,
- ServerMetricsStore,
LockStore,
SessionStore,
):
@@ -151,31 +149,6 @@ class DataStore(
],
)
- self._cache_id_gen: Optional[MultiWriterIdGenerator]
- if isinstance(self.database_engine, PostgresEngine):
- # We set the `writers` to an empty list here as we don't care about
- # missing updates over restarts, as we'll not have anything in our
- # caches to invalidate. (This reduces the amount of writes to the DB
- # that happen).
- self._cache_id_gen = MultiWriterIdGenerator(
- db_conn,
- database,
- stream_name="caches",
- instance_name=hs.get_instance_name(),
- tables=[
- (
- "cache_invalidation_stream_by_instance",
- "instance_name",
- "stream_id",
- )
- ],
- sequence_name="cache_invalidation_stream_seq",
- writers=[],
- )
-
- else:
- self._cache_id_gen = None
-
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
@@ -197,6 +170,7 @@ class DataStore(
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
def get_device_stream_token(self) -> int:
+ # TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token()
async def get_users(self) -> List[JsonDict]:
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room, (user_id,)
)
self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
- self._invalidate_cache_and_stream(
- txn, self.get_push_rules_enabled_for_user, (user_id,)
- )
# This user might be contained in the ignored_by cache for other users,
# so we have to invalidate it all.
self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e284454b66..64b70a7b28 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore(
device_list_summary=DeviceListUpdates(),
)
- async def set_appservice_last_pos(self, pos: int) -> None:
- def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
- txn.execute(
- "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
- )
+ async def get_appservice_last_pos(self) -> int:
+ """
+ Get the last stream ordering position for the appservice process.
+ """
- await self.db_pool.runInteraction(
- "set_appservice_last_pos", set_appservice_last_pos_txn
+ return await self.db_pool.simple_select_one_onecol(
+ table="appservice_stream_position",
+ retcol="stream_ordering",
+ keyvalues={},
+ desc="get_appservice_last_pos",
)
- async def get_new_events_for_appservice(
- self, current_id: int, limit: int
- ) -> Tuple[int, List[EventBase]]:
- """Get all new events for an appservice"""
-
- def get_new_events_for_appservice_txn(
- txn: LoggingTransaction,
- ) -> Tuple[int, List[str]]:
- sql = (
- "SELECT e.stream_ordering, e.event_id"
- " FROM events AS e"
- " WHERE"
- " (SELECT stream_ordering FROM appservice_stream_position)"
- " < e.stream_ordering"
- " AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- " LIMIT ?"
- )
-
- txn.execute(sql, (current_id, limit))
- rows = txn.fetchall()
-
- upper_bound = current_id
- if len(rows) == limit:
- upper_bound = rows[-1][0]
-
- return upper_bound, [row[1] for row in rows]
+ async def set_appservice_last_pos(self, pos: int) -> None:
+ """
+ Set the last stream ordering position for the appservice process.
+ """
- upper_bound, event_ids = await self.db_pool.runInteraction(
- "get_new_events_for_appservice", get_new_events_for_appservice_txn
+ await self.db_pool.simple_update_one(
+ table="appservice_stream_position",
+ keyvalues={},
+ updatevalues={"stream_ordering": pos},
+ desc="set_appservice_last_pos",
)
- events = await self.get_events_as_list(event_ids, get_prev_content=True)
-
- return upper_bound, events
-
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 1653a6a9b6..12e9a42382 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -32,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
@@ -65,6 +66,31 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
psql_only=True, # The table is only on postgres DBs.
)
+ self._cache_id_gen: Optional[MultiWriterIdGenerator]
+ if isinstance(self.database_engine, PostgresEngine):
+ # We set the `writers` to an empty list here as we don't care about
+ # missing updates over restarts, as we'll not have anything in our
+ # caches to invalidate. (This reduces the amount of writes to the DB
+ # that happen).
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ stream_name="caches",
+ instance_name=hs.get_instance_name(),
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
+ sequence_name="cache_invalidation_stream_seq",
+ writers=[],
+ )
+
+ else:
+ self._cache_id_gen = None
+
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
@@ -193,7 +219,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
relates_to: Optional[str],
backfilled: bool,
) -> None:
- self._invalidate_get_event_cache(event_id)
+        # This invalidates any local in-memory cached event objects; the original
+        # process triggering the invalidation is responsible for clearing any
+        # external cached objects.
+ self._invalidate_local_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -208,7 +237,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
- self._invalidate_get_event_cache(redacts)
+ self._invalidate_local_get_event_cache(redacts)
# Caches which might leak edits must be invalidated for the event being
# redacted.
self.get_relations_for_event.invalidate((redacts,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
- txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ self.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 599b418383..73c95ffb6f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(user_id, device_id), None
)
- set_tag("last_deleted_stream_id", last_deleted_stream_id)
+ set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
@@ -834,8 +834,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
- REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
def __init__(
@@ -857,15 +855,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
- # Used to be a background update that deletes all device_inboxes for deleted
- # devices.
- self.db_pool.updates.register_noop_background_update(
- self.REMOVE_DELETED_DEVICES
- )
- # Used to be a background update that deletes all device_inboxes for hidden
- # devices.
- self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES)
-
self.db_pool.updates.register_background_update_handler(
self.REMOVE_DEAD_DEVICES_FROM_INBOX,
self._remove_dead_devices_from_device_inbox,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 71e7863dd8..ca0fe8c4be 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,8 @@ from typing import (
cast,
)
+from typing_extensions import Literal
+
from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -44,6 +46,8 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
-class DeviceWorkerStore(SQLBaseStore):
+class DeviceWorkerStore(EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- device_list_max = self._device_list_id_gen.get_current_token()
+ # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
+ # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
+ device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
# following this stream later.
last_processed_stream_id = from_stream_id
- query_map = {}
- cross_signing_keys_by_user = {}
+ # A map of (user ID, device ID) to (stream ID, context).
+ query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
+ cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
for user_id, device_id, update_stream_id, update_context in updates:
# Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
@@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
- key_values=((destination, user_id) for user_id, _ in rows),
+ key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",),
value_values=((stream_id,) for _, stream_id in rows),
)
@@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
The new stream ID.
"""
- async with self._device_list_id_gen.get_next() as stream_id:
+ # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
+ # than DeviceWorkerStore?
+ async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -660,7 +669,7 @@ class DeviceWorkerStore(SQLBaseStore):
@trace
async def get_user_devices_from_cache(
- self, query_list: List[Tuple[str, str]]
+ self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
@@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
} - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache
- results = {}
+ results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
@@ -697,8 +706,8 @@ class DeviceWorkerStore(SQLBaseStore):
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)
- set_tag("in_cache", results)
- set_tag("not_in_cache", user_ids_not_in_cache)
+ set_tag("in_cache", str(results))
+ set_tag("not_in_cache", str(user_ids_not_in_cache))
return user_ids_not_in_cache, results
@@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
def get_cached_device_list_changes(
self,
from_key: int,
- ) -> Optional[Set[str]]:
+ ) -> Optional[List[str]]:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def get_users_whose_devices_changed(
self,
from_key: int,
- user_ids: Optional[Iterable[str]] = None,
+ user_ids: Optional[Collection[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
+ user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
- changes = set()
+ changes: Set[str] = set()
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
@@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
"""
# Query device changes with a batch of users at a time
+ # Assertion for mypy's benefit; see also
+ # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
+ assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
@@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def _get_all_device_list_changes_for_remotes(txn):
+ def _get_all_device_list_changes_for_remotes(
+ txn: Cursor,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
@@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_device_list_last_stream_id_for_remotes",
)
- results = {user_id: None for user_id in user_ids}
+ results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
return results
@@ -1193,6 +1208,65 @@ class DeviceWorkerStore(SQLBaseStore):
return devices
+ @cached()
+    async def _get_min_device_lists_changes_in_room(self) -> int:
+        """Returns the minimum stream ID for which we have entries in
+        `device_lists_changes_in_room`.
+        """
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="device_lists_changes_in_room",
+ keyvalues={},
+ retcol="COALESCE(MIN(stream_id), 0)",
+ desc="get_min_device_lists_changes_in_room",
+ )
+
+ async def get_device_list_changes_in_rooms(
+ self, room_ids: Collection[str], from_id: int
+ ) -> Optional[Set[str]]:
+ """Return the set of users whose devices have changed in the given rooms
+ since the given stream ID.
+
+ Returns None if the given stream ID is too old.
+ """
+
+ if not room_ids:
+ return set()
+
+ min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+ if min_stream_id > from_id:
+ return None
+
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_changes_in_room
+ WHERE {clause} AND stream_id >= ?
+ """
+
+ def _get_device_list_changes_in_rooms_txn(
+ txn: LoggingTransaction,
+ clause: str,
+ args: List[Any],
+ ) -> Set[str]:
+ txn.execute(sql.format(clause=clause), args)
+ return {user_id for user_id, in txn}
+
+ changes = set()
+ for chunk in batch_iter(room_ids, 1000):
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", chunk
+ )
+ args.append(from_id)
+
+ changes |= await self.db_pool.runInteraction(
+ "get_device_list_changes_in_rooms",
+ _get_device_list_changes_in_rooms_txn,
+ clause,
+ args,
+ )
+
+ return changes
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -1240,15 +1314,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._remove_duplicate_outbound_pokes,
)
- # a pair of background updates that were added during the 1.14 release cycle,
- # but replaced with 58/06dlols_unique_idx.py
- self.db_pool.updates.register_noop_background_update(
- "device_lists_outbound_last_success_unique_idx",
- )
- self.db_pool.updates.register_noop_background_update(
- "drop_device_lists_outbound_last_success_non_unique_idx",
- )
-
async def _drop_device_list_streams_non_unique_indexes(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1346,9 +1411,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
- self.device_id_exists_cache = LruCache(
- cache_name="device_id_exists", max_size=10000
- )
+ self.device_id_exists_cache: LruCache[
+ Tuple[str, str], Literal[True]
+ ] = LruCache(cache_name="device_id_exists", max_size=10000)
async def store_device(
self,
@@ -1660,7 +1725,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context,
)
- async with self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1713,7 +1778,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_ids: Iterable[str],
hosts: Collection[str],
stream_ids: List[int],
- context: Dict[str, str],
+ context: Optional[Dict[str, str]],
) -> None:
for host in hosts:
txn.call_after(
@@ -1884,7 +1949,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[],
)
- async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..46c0d06157 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import (
List,
Optional,
Tuple,
+ Union,
cast,
+ overload,
)
import attr
from canonicaljson import encode_canonical_json
+from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
@@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
- result = {"device_id": device_id}
+ result: JsonDict = {"device_id": device_id}
keys = device.keys
if keys:
@@ -143,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
- set_tag("query_list", query_list)
+ set_tag("query_list", str(query_list))
if not query_list:
return {}
@@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = device_info.keys
+ if r is None:
+ continue
+
r["unsigned"] = {}
display_name = device_info.display_name
if display_name is not None:
@@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return rv
+ @overload
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: Collection[Tuple[str, Optional[str]]],
+ include_all_devices: Literal[False] = False,
+ ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+ ...
+
+ @overload
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: Collection[Tuple[str, Optional[str]]],
+ include_all_devices: bool = False,
+ include_deleted_devices: Literal[False] = False,
+ ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+ ...
+
+ @overload
+ async def get_e2e_device_keys_and_signatures(
+ self,
+ query_list: Collection[Tuple[str, Optional[str]]],
+ include_all_devices: Literal[True],
+ include_deleted_devices: Literal[True],
+ ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ ...
+
@trace
async def get_e2e_device_keys_and_signatures(
self,
- query_list: List[Tuple[str, Optional[str]]],
+ query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
- ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+ ) -> Union[
+ Dict[str, Dict[str, DeviceKeyLookupResult]],
+ Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
+ ]:
"""Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
@@ -383,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
- set_tag("new_keys", new_keys)
+ set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
- row = await self.db_pool.runInteraction(
+ claim_row = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
@@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
algorithm,
db_autocommit=db_autocommit,
)
- if row:
+ if claim_row:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- device_results[row[0]] = row[1]
+ device_results[claim_row[0]] = claim_row[1]
continue
# No one-time key available, so see if there's a fallback
@@ -1126,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
- set_tag("device_keys", device_keys)
+ set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(event_ids)
+ @trace
+ @tag_args
async def get_auth_chain_ids(
self,
room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
+ @trace
+ @tag_args
async def get_oldest_event_ids_with_depth_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
+ @trace
async def get_insertion_event_backward_extremities_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_results.reverse()
return event_results
+ @trace
+ @tag_args
async def get_successor_events(self, event_id: str) -> List[str]:
"""Fetch all events that have the given event as a prev event
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ @trace
async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
await self.db_pool.simple_upsert(
table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b019979350..f4a07de2a3 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,18 +12,92 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Responsible for storing and fetching push actions / notifications.
+
+There are two main uses for push actions:
+ 1. Sending out push to a user's device; and
+ 2. Tracking per-room per-user notification counts (used in sync requests).
+
+For the former we simply use the `event_push_actions` table, which contains all
+the actions for a given user, as calculated by the `BulkPushRuleEvaluator`.
+
+For the latter we could simply count the number of rows in the
+`event_push_actions` table for a given room/user, but in practice this is *very*
+heavyweight when there are a large number of notifications (e.g. because the
+user never reads a room). Plus, keeping all push actions indefinitely uses a lot
+of disk space.
+
+To fix these issues, we add a new table `event_push_summary` that tracks
+per-user per-room counts of all notifications that happened before a stream
+ordering S. Thus, to get the notification count for a user / room we can simply
+query a single row in `event_push_summary` and count the number of rows in
+`event_push_actions` with a stream ordering larger than S (and as long as S is
+"recent", the number of rows needing to be scanned will be small).
+
+The `event_push_summary` table is updated via a background job that periodically
+chooses a new stream ordering S' (usually the latest stream ordering), counts
+all notifications in `event_push_actions` between the existing S and S', and
+adds them to the existing counts in `event_push_summary`.
+
+This allows us to delete old rows from `event_push_actions` once those rows have
+been counted and added to `event_push_summary` (we call this process
+"rotation").
+
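+As a rough worked example (the numbers here are invented for illustration): if
+`event_push_summary` records 5 notifications for a user/room up to stream
+ordering S = 100, and `event_push_actions` contains 2 matching rows with a
+stream ordering greater than 100, then the count reported to sync is 5 + 2 = 7.
+Once the rotation job advances S past those rows, the summary itself becomes 7
+and the 2 rows become eligible for deletion.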
+
+We also need to handle the case where a user sends a read receipt to a room.
+Again this is done as a background process. For each receipt we clear the row in
+`event_push_summary` and count the number of notifications in
+`event_push_actions` that happened after the receipt but before S, and insert
+that count into `event_push_summary`. (If the receipt happened *after* S then we
+simply clear the row in `event_push_summary`.)
+
+Note that it's possible that, if the read receipt is for an old event, the
+relevant `event_push_actions` rows will already have been rotated and we will
+get the wrong count (it'll be too low). We accept this as a rare edge case that
+is unlikely to impact the user much (since the vast majority of read receipts
+will be for the latest event).
+
+The last complication is handling the race where we request the notification
+counts after a user sends a read receipt into the room, but *before* the
+background update has handled the receipt (without any special handling the
+counts would be outdated). We fix this by recording in `event_push_summary` the
+read receipt we used when updating it, and every time we query the table we
+check whether that matches the most recent read receipt in the room. If it does,
+we continue as above; if not, we simply query the `event_push_actions` table
+directly.
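+
+As another rough illustration (again with invented numbers): suppose the summary
+row was last updated against a receipt at stream ordering 90, but the user's
+most recent receipt is now at 120. The stored counts are ignored and the
+`event_push_actions` rows after 120 are counted directly; once the background
+job processes the new receipt, the summary row is rewritten against 120.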
+
+Since read receipts are almost always for recent events, scanning the
+`event_push_actions` table in this case is unlikely to be a problem. Even if it
+is a problem, it is temporary until the background job handles the new read
+receipt.
+"""
+
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
+from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -79,18 +153,20 @@ class UserPushAction(EmailPushAction):
profile_tag: str
-@attr.s(slots=True, frozen=True, auto_attribs=True)
+@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
The per-user, per-room count of notifications. Used by sync and push.
"""
- notify_count: int
- unread_count: int
- highlight_count: int
+ notify_count: int = 0
+ unread_count: int = 0
+ highlight_count: int = 0
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+ actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -119,7 +195,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st
return DEFAULT_NOTIF_ACTION
-class EventPushActionsWorkerStore(SQLBaseStore):
+class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBaseStore):
def __init__(
self,
database: DatabasePool,
@@ -140,23 +216,30 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._find_stream_orderings_for_times, 10 * 60 * 1000
)
- self._rotate_delay = 3
self._rotate_count = 10000
self._doing_notif_rotation = False
if hs.config.worker.run_background_tasks:
self._rotate_notif_loop = self._clock.looping_call(
- self._rotate_notifs, 30 * 60 * 1000
+ self._rotate_notifs, 30 * 1000
)
- @cached(num_args=3, tree=True, max_entries=5000)
+ self.db_pool.updates.register_background_index_update(
+ "event_push_summary_unique_index",
+ index_name="event_push_summary_unique_index",
+ table="event_push_summary",
+ columns=["user_id", "room_id"],
+ unique=True,
+ replaces_index="event_push_summary_user_rm",
+ )
+
+ @cached(tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
user_id: str,
- last_read_event_id: Optional[str],
) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
- for a given user in a given room after the given read receipt.
+ for a given user in a given room after their latest read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
@@ -165,20 +248,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Args:
room_id: The room to retrieve the counts in.
user_id: The user to retrieve the counts for.
- last_read_event_id: The event associated with the latest read receipt for
- this user in this room. None if no receipt for this user in this room.
Returns
- A dict containing the counts mentioned earlier in this docstring,
- respectively under the keys "notify_count", "highlight_count" and
- "unread_count".
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
- last_read_event_id,
)
def _get_unread_counts_by_receipt_txn(
@@ -186,20 +265,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn: LoggingTransaction,
room_id: str,
user_id: str,
- last_read_event_id: Optional[str],
) -> NotifCounts:
- stream_ordering = None
+ # Get the stream ordering of the user's latest receipt in the room.
+ result = self.get_last_receipt_for_user_txn(
+ txn,
+ user_id,
+ room_id,
+ receipt_types=(
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ),
+ )
- if last_read_event_id is not None:
- stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined]
- txn,
- last_read_event_id,
- allow_none=True,
- )
+ if result:
+ _, stream_ordering = result
- if stream_ordering is None:
- # Either last_read_event_id is None, or it's an event we don't have (e.g.
- # because it's been purged), in which case retrieve the stream ordering for
+ else:
+ # If the user has no receipts in the room, retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -209,57 +291,159 @@ class EventPushActionsWorkerStore(SQLBaseStore):
retcol="event_id",
)
- stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined]
+ stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(
- self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ receipt_stream_ordering: int,
) -> NotifCounts:
- sql = (
- "SELECT"
- " COUNT(CASE WHEN notif = 1 THEN 1 END),"
- " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
- " COUNT(CASE WHEN unread = 1 THEN 1 END)"
- " FROM event_push_actions ea"
- " WHERE user_id = ?"
- " AND room_id = ?"
- " AND stream_ordering > ?"
- )
+ """Get the number of unread messages for a user/room that have happened
+ since the given stream ordering.
- txn.execute(sql, (user_id, room_id, stream_ordering))
- row = txn.fetchone()
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ receipt_stream_ordering: The stream ordering of the user's latest
+ receipt in the room. If there are no receipts, the stream ordering
+ of the user's join event.
- (notif_count, highlight_count, unread_count) = (0, 0, 0)
+ Returns
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
+ """
- if row:
- (notif_count, highlight_count, unread_count) = row
+ counts = NotifCounts()
+ # First we pull the counts from the summary table.
+ #
+ # We check that `last_receipt_stream_ordering` matches the stream
+ # ordering given. If it doesn't match then a new read receipt has arrived and
+ # we haven't yet updated the counts in `event_push_summary` to reflect
+ # that; in that case we simply ignore `event_push_summary` counts
+ # and do a manual count of all of the rows in the `event_push_actions` table
+ # for this user/room.
+ #
+ # If `last_receipt_stream_ordering` is null then that means it's up to
+ # date (as the row was written by an older version of Synapse that
+ # updated `event_push_summary` synchronously when persisting a new read
+ # receipt).
txn.execute(
"""
- SELECT notif_count, unread_count FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+ SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+ FROM event_push_summary
+ WHERE room_id = ? AND user_id = ?
+ AND (
+ (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
+ OR last_receipt_stream_ordering = ?
+ )
""",
- (room_id, user_id, stream_ordering),
+ (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
)
row = txn.fetchone()
+ summary_stream_ordering = 0
+ if row:
+ summary_stream_ordering = row[0]
+ counts.notify_count += row[1]
+ counts.unread_count += row[2]
+
+ # Next we need to count highlights, which aren't summarised
+ sql = """
+ SELECT COUNT(*) FROM event_push_actions
+ WHERE user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ AND highlight = 1
+ """
+ txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
+ row = txn.fetchone()
if row:
- notif_count += row[0]
-
- if row[1] is not None:
- # The unread_count column of event_push_summary is NULLable, so we need
- # to make sure we don't try increasing the unread counts if it's NULL
- # for this row.
- unread_count += row[1]
-
- return NotifCounts(
- notify_count=notif_count,
- unread_count=unread_count,
- highlight_count=highlight_count,
+ counts.highlight_count += row[0]
+
+        # Finally we need to count push actions that aren't included in the
+        # summary returned above. This might be because of recent events that
+        # haven't been summarised yet, or because the summary is out of date due
+        # to a recent read receipt.
+ start_unread_stream_ordering = max(
+ receipt_stream_ordering, summary_stream_ordering
)
+ notify_count, unread_count = self._get_notif_unread_count_for_user_room(
+ txn, room_id, user_id, start_unread_stream_ordering
+ )
+
+ counts.notify_count += notify_count
+ counts.unread_count += unread_count
+
+ return counts
+
+ def _get_notif_unread_count_for_user_room(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ stream_ordering: int,
+ max_stream_ordering: Optional[int] = None,
+ ) -> Tuple[int, int]:
+ """Returns the notify and unread counts from `event_push_actions` for
+ the given user/room in the given range.
+
+        Does not consult the `event_push_summary` table, which may include push
+        actions that have been deleted from the `event_push_actions` table.
+
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ stream_ordering: The (exclusive) minimum stream ordering to consider.
+ max_stream_ordering: The (inclusive) maximum stream ordering to consider.
+ If this is not given, then no maximum is applied.
+
+        Returns
+ A tuple of the notif count and unread count in the given range.
+ """
+
+ # If there have been no events in the room since the stream ordering,
+ # there can't be any push actions either.
+ if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
+ return 0, 0
+
+ clause = ""
+ args = [user_id, room_id, stream_ordering]
+ if max_stream_ordering is not None:
+ clause = "AND ea.stream_ordering <= ?"
+ args.append(max_stream_ordering)
+
+            # If the max stream ordering is less than or equal to the min stream
+            # ordering, then obviously there are zero push actions in that range.
+            if max_stream_ordering <= stream_ordering:
+                return 0, 0
+
+ sql = f"""
+ SELECT
+ COUNT(CASE WHEN notif = 1 THEN 1 END),
+ COUNT(CASE WHEN unread = 1 THEN 1 END)
+ FROM event_push_actions ea
+ WHERE user_id = ?
+ AND room_id = ?
+ AND ea.stream_ordering > ?
+ {clause}
+ """
+
+ txn.execute(sql, args)
+ row = txn.fetchone()
+
+ if row:
+ return cast(Tuple[int, int], row)
+
+ return 0, 0
async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int
@@ -274,6 +458,31 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
+    def _get_receipts_by_room_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> List[Tuple[str, int]]:
+        """Fetch the stream ordering of the latest read receipt the given user
+        has in each room, as a list of (room ID, stream ordering) pairs.
+        """
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT room_id, MAX(stream_ordering)
+ FROM receipts_linearized
+ INNER JOIN events USING (room_id, event_id)
+ WHERE {receipt_types_clause}
+ AND user_id = ?
+ GROUP BY room_id
+ """
+
+ args.extend((user_id,))
+ txn.execute(sql, args)
+ return cast(List[Tuple[str, int]], txn.fetchall())
+
async def get_unread_push_actions_for_user_in_range_for_http(
self,
user_id: str,
@@ -296,81 +505,46 @@ class EventPushActionsWorkerStore(SQLBaseStore):
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
- # find rooms that have a read receipt in them and return the next
- # push actions
- def get_after_receipt(
- txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool]]:
- # find rooms that have a read receipt in them and return the next
- # push actions
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight "
- " FROM ("
- " SELECT room_id,"
- " MAX(stream_ordering) as stream_ordering"
- " FROM events"
- " INNER JOIN receipts_linearized USING (room_id, event_id)"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- ") AS rl,"
- " event_push_actions AS ep"
- " WHERE"
- " ep.room_id = rl.room_id"
- " AND ep.stream_ordering > rl.stream_ordering"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering ASC LIMIT ?"
- )
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
- txn.execute(sql, args)
- return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
- after_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+ receipts_by_room = dict(
+ await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http_receipts",
+ self._get_receipts_by_room_txn,
+ user_id=user_id,
+ ),
)
- # There are rooms with push actions in them but you don't have a read receipt in
- # them e.g. rooms you've been invited to, so get push actions for rooms which do
- # not have read receipts in them too.
- def get_no_receipt(
+ def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight "
- " FROM event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id NOT IN ("
- " SELECT room_id FROM receipts_linearized"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- " )"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering ASC LIMIT ?"
- )
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
- txn.execute(sql, args)
+ sql = """
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
+ FROM event_push_actions AS ep
+ WHERE
+ ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering ASC LIMIT ?
+ """
+ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
- no_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+ push_actions = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
)
notifs = [
HttpPushAction(
- event_id=row[0],
- room_id=row[1],
- stream_ordering=row[2],
- actions=_deserialize_action(row[3], row[4]),
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ actions=_deserialize_action(actions, highlight),
)
- for row in after_read_receipt + no_read_receipt
+ for event_id, room_id, stream_ordering, actions, highlight in push_actions
+            # Only include push actions with a stream ordering after any receipt,
+            # or with no receipt at all (e.g. rooms the user was invited to but
+            # has never read).
+ if stream_ordering > receipts_by_room.get(room_id, 0)
]
# Now sort it so it's ordered correctly, since currently it will
@@ -405,82 +579,50 @@ class EventPushActionsWorkerStore(SQLBaseStore):
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
- # find rooms that have a read receipt in them and return the most recent
- # push actions
- def get_after_receipt(
- txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool, int]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight, e.received_ts"
- " FROM ("
- " SELECT room_id,"
- " MAX(stream_ordering) as stream_ordering"
- " FROM events"
- " INNER JOIN receipts_linearized USING (room_id, event_id)"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- ") AS rl,"
- " event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id = rl.room_id"
- " AND ep.stream_ordering > rl.stream_ordering"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering DESC LIMIT ?"
- )
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
- txn.execute(sql, args)
- return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
- after_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+ receipts_by_room = dict(
+ await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email_receipts",
+ self._get_receipts_by_room_txn,
+ user_id=user_id,
+ ),
)
- # There are rooms with push actions in them but you don't have a read receipt in
- # them e.g. rooms you've been invited to, so get push actions for rooms which do
- # not have read receipts in them too.
- def get_no_receipt(
+ def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight, e.received_ts"
- " FROM event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id NOT IN ("
- " SELECT room_id FROM receipts_linearized"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- " )"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering DESC LIMIT ?"
- )
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
- txn.execute(sql, args)
+ sql = """
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight, e.received_ts
+ FROM event_push_actions AS ep
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering DESC LIMIT ?
+ """
+ txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
- no_read_receipt = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+ push_actions = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
)
# Make a list of dicts from the two sets of results.
notifs = [
EmailPushAction(
- event_id=row[0],
- room_id=row[1],
- stream_ordering=row[2],
- actions=_deserialize_action(row[3], row[4]),
- received_ts=row[5],
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ actions=_deserialize_action(actions, highlight),
+ received_ts=received_ts,
)
- for row in after_read_receipt + no_read_receipt
+ for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+            # Only include push actions with a stream ordering after any receipt,
+            # or with no receipt at all (e.g. rooms the user was invited to but
+            # has never read).
+ if stream_ordering > receipts_by_room.get(room_id, 0)
]
# Now sort it so it's ordered correctly, since currently it will
@@ -526,7 +668,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
async def add_push_actions_to_staging(
self,
event_id: str,
- user_id_actions: Dict[str, List[Union[dict, str]]],
+ user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -543,7 +685,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
- user_id: str, actions: List[Union[dict, str]]
+ user_id: str, actions: Collection[Union[Mapping, str]]
) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
@@ -556,26 +698,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
int(count_as_unread), # unread column
)
- def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
- # We don't use simple_insert_many here to avoid the overhead
- # of generating lists of dicts.
-
- sql = """
- INSERT INTO event_push_actions_staging
- (event_id, user_id, actions, notif, highlight, unread)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- txn.execute_batch(
- sql,
- (
- _gen_entry(user_id, actions)
- for user_id, actions in user_id_actions.items()
- ),
- )
-
- return await self.db_pool.runInteraction(
- "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+ await self.db_pool.simple_insert_many(
+ "event_push_actions_staging",
+ keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"),
+ values=[
+ _gen_entry(user_id, actions)
+ for user_id, actions in user_id_actions.items()
+ ],
+ desc="add_push_actions_to_staging",
)
async def remove_push_actions_from_staging(self, event_id: str) -> None:
@@ -689,12 +819,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# [10, <none>, 20], we should treat this as being equivalent to
# [10, 10, 20].
#
- sql = (
- "SELECT received_ts FROM events"
- " WHERE stream_ordering <= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT 1"
- )
+ sql = """
+ SELECT received_ts FROM events
+ WHERE stream_ordering <= ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
while range_end - range_start > 0:
middle = (range_end + range_start) // 2
@@ -722,14 +852,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self, stream_ordering: int
) -> Optional[int]:
def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
- sql = (
- "SELECT e.received_ts"
- " FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ? AND notif = 1"
- " ORDER BY ep.stream_ordering ASC"
- " LIMIT 1"
- )
+ sql = """
+ SELECT e.received_ts
+ FROM event_push_actions AS ep
+ JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
+ WHERE ep.stream_ordering > ? AND notif = 1
+ ORDER BY ep.stream_ordering ASC
+ LIMIT 1
+ """
txn.execute(sql, (stream_ordering,))
return cast(Optional[Tuple[int]], txn.fetchone())
@@ -745,6 +875,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._doing_notif_rotation = True
try:
+ # First we recalculate push summaries and delete stale push actions
+ # for rooms/users with new receipts.
+ while True:
+ logger.debug("Handling new receipts")
+
+ caught_up = await self.db_pool.runInteraction(
+ "_handle_new_receipts_for_notifs_txn",
+ self._handle_new_receipts_for_notifs_txn,
+ )
+ if caught_up:
+ break
+
+ # Then we update the event push summaries for any new events
while True:
logger.info("Rotating notifications")
@@ -753,15 +896,129 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
if caught_up:
break
- await self.hs.get_clock().sleep(self._rotate_delay)
+
+ # Finally we clear out old event push actions.
+ await self._remove_old_push_actions_that_have_rotated()
finally:
self._doing_notif_rotation = False
+ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool:
+ """Check for new read receipts and delete from event push actions.
+
+ Any push actions which predate the user's most recent read receipt are
+ now redundant, so we can remove them from `event_push_actions` and
+ update `event_push_summary`.
+
+ Returns true if all new receipts have been processed.
+ """
+
+ limit = 100
+
+        # The (inclusive) receipt stream ID that was previously processed.
+ min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_last_receipt_stream_id",
+ keyvalues={},
+ retcol="stream_id",
+ )
+
+ max_receipts_stream_id = self._receipts_id_gen.get_current_token()
+
+ # The (inclusive) event stream ordering that was previously summarised.
+ old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
+ sql = """
+ SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
+ FROM receipts_linearized AS r
+ INNER JOIN events AS e USING (event_id)
+ WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
+ ORDER BY r.stream_id ASC
+ LIMIT ?
+ """
+
+ # We only want local users, so we add a dodgy filter to the above query
+ # and recheck it below.
+ user_filter = "%:" + self.hs.hostname
+
+ txn.execute(
+ sql,
+ (
+ min_receipts_stream_id,
+ max_receipts_stream_id,
+ user_filter,
+ limit,
+ ),
+ )
+ rows = txn.fetchall()
+
+ # For each new read receipt we delete push actions from before it and
+ # recalculate the summary.
+ for _, room_id, user_id, stream_ordering in rows:
+ # Only handle our own read receipts.
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ txn.execute(
+ """
+ DELETE FROM event_push_actions
+ WHERE room_id = ?
+ AND user_id = ?
+ AND stream_ordering <= ?
+ AND highlight = 0
+ """,
+ (room_id, user_id, stream_ordering),
+ )
+
+ # Fetch the notification counts between the stream ordering of the
+ # latest receipt and what was previously summarised.
+ notif_count, unread_count = self._get_notif_unread_count_for_user_room(
+ txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+ )
+
+ # Replace the previous summary with the new counts.
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="event_push_summary",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ values={
+ "notif_count": notif_count,
+ "unread_count": unread_count,
+ "stream_ordering": old_rotate_stream_ordering,
+ "last_receipt_stream_ordering": stream_ordering,
+ },
+ )
+
+ # We always update `event_push_summary_last_receipt_stream_id` to
+ # ensure that we don't rescan the same receipts for remote users.
+
+ upper_limit = max_receipts_stream_id
+ if len(rows) >= limit:
+ # If we hit the row limit, we only advance the position to the last
+ # receipt we actually processed, so the remainder is picked up on the
+ # next iteration.
+ upper_limit = rows[-1][0]
+
+ self.db_pool.simple_update_txn(
+ txn,
+ table="event_push_summary_last_receipt_stream_id",
+ keyvalues={},
+ updatevalues={"stream_id": upper_limit},
+ )
+
+ return len(rows) < limit
+
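The receipt-handling loop above follows a watermark pattern: process at most `limit` receipts per transaction, persist how far we got, and resume from that position on the next pass. A minimal, self-contained sketch of that control flow; the `load_position`, `save_position`, `fetch_receipts` and `handle_receipt` callables are hypothetical stand-ins for the real storage helpers:

def process_new_receipts(load_position, save_position, fetch_receipts, handle_receipt, limit=100):
    """Drain new receipts in bounded batches, persisting a watermark between batches.

    `fetch_receipts(after, limit)` is assumed to return
    (stream_id, room_id, user_id, stream_ordering) tuples ordered by stream_id.
    """
    while True:
        last_processed = load_position()  # inclusive stream_id already handled
        rows = fetch_receipts(after=last_processed, limit=limit)

        for stream_id, room_id, user_id, stream_ordering in rows:
            handle_receipt(room_id, user_id, stream_ordering)

        if rows:
            # Only advance the watermark as far as we actually got, so a crash
            # between batches re-processes at most one batch and never skips rows.
            save_position(rows[-1][0])

        if len(rows) < limit:
            return  # a short (or empty) batch means we have caught up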
def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
- """Archives older notifications into event_push_summary. Returns whether
- the archiving process has caught up or not.
+ """Archives older notifications (from event_push_actions) into event_push_summary.
+
+ Returns whether the archiving process has caught up or not.
"""
+ # The (inclusive) event stream ordering that was previously summarised.
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
@@ -776,50 +1033,64 @@ class EventPushActionsWorkerStore(SQLBaseStore):
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
- """,
+ """,
(old_rotate_stream_ordering, self._rotate_count),
)
stream_row = txn.fetchone()
if stream_row:
(offset_stream_ordering,) = stream_row
- assert self.stream_ordering_day_ago is not None
+
+ # We need to bound by the current token to ensure that we handle
+ # out-of-order writes correctly.
rotate_to_stream_ordering = min(
- self.stream_ordering_day_ago, offset_stream_ordering
+ offset_stream_ordering, self._stream_id_gen.get_current_token()
)
- caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+ caught_up = False
else:
- rotate_to_stream_ordering = self.stream_ordering_day_ago
+ rotate_to_stream_ordering = self._stream_id_gen.get_current_token()
caught_up = True
logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
- self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+ self._rotate_notifs_before_txn(
+ txn, old_rotate_stream_ordering, rotate_to_stream_ordering
+ )
- # We have caught up iff we were limited by `stream_ordering_day_ago`
return caught_up
def _rotate_notifs_before_txn(
- self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ old_rotate_stream_ordering: int,
+ rotate_to_stream_ordering: int,
) -> None:
- old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_push_summary_stream_ordering",
- keyvalues={},
- retcol="stream_ordering",
- )
+ """Archives older notifications (from event_push_actions) into event_push_summary.
+
+ Any event_push_actions between old_rotate_stream_ordering (exclusive) and
+ rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
+ table.
+
+ Args:
+ txn: The database transaction.
+ old_rotate_stream_ordering: The previous maximum event stream ordering.
+ rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
+ """
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
coalesce(old.%s, 0) + upd.cnt,
- upd.stream_ordering,
- old.user_id
+ upd.stream_ordering
FROM (
SELECT user_id, room_id, count(*) as cnt,
- max(stream_ordering) as stream_ordering
- FROM event_push_actions
- WHERE ? <= stream_ordering AND stream_ordering < ?
- AND highlight = 0
+ max(ea.stream_ordering) as stream_ordering
+ FROM event_push_actions AS ea
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
+ AND (
+ old.last_receipt_stream_ordering IS NULL
+ OR old.last_receipt_stream_ordering < ea.stream_ordering
+ )
AND %s = 1
GROUP BY user_id, room_id
) AS upd
@@ -842,7 +1113,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
stream_ordering=row[3],
- old_user_id=row[4],
notif_count=0,
)
@@ -863,115 +1133,93 @@ class EventPushActionsWorkerStore(SQLBaseStore):
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=0,
stream_ordering=row[3],
- old_user_id=row[4],
notif_count=row[2],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
- # If the `old.user_id` above is NULL then we know there isn't already an
- # entry in the table, so we simply insert it. Otherwise we update the
- # existing table.
- self.db_pool.simple_insert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
- keys=(
- "user_id",
- "room_id",
- "notif_count",
- "unread_count",
- "stream_ordering",
- ),
- values=[
+ key_names=("user_id", "room_id"),
+ key_values=[(user_id, room_id) for user_id, room_id in summaries],
+ value_names=("notif_count", "unread_count", "stream_ordering"),
+ value_values=[
(
- user_id,
- room_id,
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
)
- for ((user_id, room_id), summary) in summaries.items()
- if summary.old_user_id is None
+ for summary in summaries.values()
],
)
- txn.execute_batch(
- """
- UPDATE event_push_summary
- SET notif_count = ?, unread_count = ?, stream_ordering = ?
- WHERE user_id = ? AND room_id = ?
- """,
- (
- (
- summary.notif_count,
- summary.unread_count,
- summary.stream_ordering,
- user_id,
- room_id,
- )
- for ((user_id, room_id), summary) in summaries.items()
- if summary.old_user_id is not None
- ),
- )
-
- txn.execute(
- "DELETE FROM event_push_actions"
- " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
- (old_rotate_stream_ordering, rotate_to_stream_ordering),
- )
-
- logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
-
txn.execute(
"UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
(rotate_to_stream_ordering,),
)
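The rotation keeps `event_push_summary` incremental: each pass adds the number of push actions falling in the window (old_rotate_stream_ordering, rotate_to_stream_ordering] on top of whatever the previous summary already held, rather than recounting from scratch. A rough in-memory model of that additive invariant (a plain Counter standing in for the table, not the storage API):

from collections import Counter
from typing import Iterable, Tuple

def rotate_summaries(
    summaries: Counter,
    actions: Iterable[Tuple[str, str, int]],
    old_so: int,
    new_so: int,
) -> None:
    """Fold actions with old_so < stream_ordering <= new_so into the running counts.

    `summaries` maps (user_id, room_id) -> count, mirroring event_push_summary;
    `actions` yields (user_id, room_id, stream_ordering) tuples.
    """
    for user_id, room_id, stream_ordering in actions:
        if old_so < stream_ordering <= new_so:
            summaries[(user_id, room_id)] += 1

Successive calls with adjoining windows accumulate the same totals as a single pass over all actions, which is what lets old rows be deleted once they have been rotated.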
- def _remove_old_push_actions_before_txn(
- self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
- ) -> None:
- """
- Purges old push actions for a user and room before a given
- stream_ordering.
-
- We however keep a months worth of highlighted notifications, so that
- users can still get a list of recent highlights.
+ async def _remove_old_push_actions_that_have_rotated(self) -> None:
+ """Clear out old push actions that have been summarised."""
- Args:
- txn: The transaction
- room_id: Room ID to delete from
- user_id: user ID to delete for
- stream_ordering: The lowest stream ordering which will
- not be deleted.
- """
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate,
- (room_id, user_id),
+ # We want to clear out anything that is older than a day that *has* already
+ # been rotated.
+ rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
)
- # We need to join on the events table to get the received_ts for
- # event_push_actions and sqlite won't let us use a join in a delete so
- # we can't just delete where received_ts < x. Furthermore we can
- # only identify event_push_actions by a tuple of room_id, event_id
- # we we can't use a subquery.
- # Instead, we look up the stream ordering for the last event in that
- # room received before the threshold time and delete event_push_actions
- # in the room with a stream_odering before that.
- txn.execute(
- "DELETE FROM event_push_actions "
- " WHERE user_id = ? AND room_id = ? AND "
- " stream_ordering <= ?"
- " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+ max_stream_ordering_to_delete = min(
+ rotated_upto_stream_ordering, self.stream_ordering_day_ago
)
- txn.execute(
- """
- DELETE FROM event_push_summary
- WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
- """,
- (room_id, user_id, stream_ordering),
- )
+ def remove_old_push_actions_that_have_rotated_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
+ # We don't want to clear out too much at a time, so we bound our
+ # deletes.
+ batch_size = self._rotate_count
+
+ txn.execute(
+ """
+ SELECT stream_ordering FROM event_push_actions
+ WHERE stream_ordering <= ? AND highlight = 0
+ ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
+ """,
+ (
+ max_stream_ordering_to_delete,
+ batch_size,
+ ),
+ )
+ stream_row = txn.fetchone()
+
+ if stream_row:
+ (stream_ordering,) = stream_row
+ else:
+ stream_ordering = max_stream_ordering_to_delete
+
+ # We need to use an inclusive bound here to handle the case where a
+ # single stream ordering has more than `batch_size` rows.
+ txn.execute(
+ """
+ DELETE FROM event_push_actions
+ WHERE stream_ordering <= ? AND highlight = 0
+ """,
+ (stream_ordering,),
+ )
+
+ logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
+
+ return txn.rowcount < batch_size
+
+ while True:
+ done = await self.db_pool.runInteraction(
+ "_remove_old_push_actions_that_have_rotated",
+ remove_old_push_actions_that_have_rotated_txn,
+ )
+ if done:
+ break
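Deleting the already-summarised actions is likewise chunked: each pass finds the stream ordering roughly `batch_size` rows ahead, deletes everything up to and including it, and stops once a pass removes fewer rows than the batch size. A standalone sketch of that loop against SQLite, assuming an `event_push_actions` table with `stream_ordering` and `highlight` columns:

import sqlite3

def delete_rotated_actions(
    conn: sqlite3.Connection, max_stream_ordering: int, batch_size: int = 10000
) -> None:
    """Delete summarised, non-highlight push actions in bounded batches."""
    while True:
        row = conn.execute(
            """
            SELECT stream_ordering FROM event_push_actions
            WHERE stream_ordering <= ? AND highlight = 0
            ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
            """,
            (max_stream_ordering, batch_size),
        ).fetchone()
        # If fewer than batch_size rows remain, fall back to the overall bound.
        bound = row[0] if row else max_stream_ordering

        cur = conn.execute(
            "DELETE FROM event_push_actions WHERE stream_ordering <= ? AND highlight = 0",
            (bound,),
        )
        conn.commit()
        if cur.rowcount < batch_size:
            break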
class EventPushActionsStore(EventPushActionsWorkerStore):
@@ -1000,6 +1248,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
where_clause="highlight=1",
)
+ # Add index to make deleting old push actions faster.
+ self.db_pool.updates.register_background_index_update(
+ "event_push_actions_stream_highlight_index",
+ index_name="event_push_actions_stream_highlight_index",
+ table="event_push_actions",
+ columns=["highlight", "stream_ordering"],
+ where_clause="highlight=0",
+ psql_only=True,
+ )
+
async def get_push_actions_for_user(
self,
user_id: str,
@@ -1024,16 +1282,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# NB. This assumes event_ids are globally unique since
# it makes the query easier to index
- sql = (
- "SELECT epa.event_id, epa.room_id,"
- " epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
- " FROM event_push_actions epa, events e"
- " WHERE epa.event_id = e.event_id"
- " AND epa.user_id = ? %s"
- " AND epa.notif = 1"
- " ORDER BY epa.stream_ordering DESC"
- " LIMIT ?" % (before_clause,)
+ sql = """
+ SELECT epa.event_id, epa.room_id,
+ epa.stream_ordering, epa.topological_ordering,
+ epa.actions, epa.highlight, epa.profile_tag, e.received_ts
+ FROM event_push_actions epa, events e
+ WHERE epa.event_id = e.event_id
+ AND epa.user_id = ? %s
+ AND epa.notif = 1
+ ORDER BY epa.stream_ordering DESC
+ LIMIT ?
+ """ % (
+ before_clause,
)
txn.execute(sql, args)
return cast(
@@ -1056,7 +1316,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
]
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
for action in actions:
if not isinstance(action, dict):
continue
@@ -1075,5 +1335,4 @@ class _EventPushSummary:
unread_count: int
stream_ordering: int
- old_user_id: str
notif_count: int
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 17e35cf63e..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,6 +16,7 @@
import itertools
import logging
from collections import OrderedDict
+from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -35,9 +36,11 @@ from prometheus_client import Counter
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -46,7 +49,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import JsonDict, StateMap, get_domain_from_id
@@ -69,6 +72,24 @@ event_counter = Counter(
)
+class PartialStateConflictError(SynapseError):
+ """An internal error raised when attempting to persist an event with partial state
+ after the room containing the event has been un-partial stated.
+
+ This error should be handled by recomputing the event context and trying again.
+
+ This error has an HTTP status code so that it can be transported over replication.
+ It should not be exposed to clients.
+ """
+
+ def __init__(self) -> None:
+ super().__init__(
+ HTTPStatus.CONFLICT,
+ msg="Cannot persist partial state event in un-partial stated room",
+ errcode=Codes.UNKNOWN,
+ )
+
+
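Callers of the persistence path are expected to catch `PartialStateConflictError`, recompute the event context now that the room has its full state, and retry. A hedged sketch of that retry shape; `persist` and `recompute_context` are placeholder async callables, not Synapse functions:

async def persist_with_partial_state_retry(event, context, persist, recompute_context):
    """Persist an event, retrying once if the room was un-partial stated mid-flight."""
    try:
        return await persist(event, context)
    except PartialStateConflictError:
        # The room gained full state while we were working: the old context was
        # built against partial state, so rebuild it and try again.
        fresh_context = await recompute_context(event)
        return await persist(event, fresh_context)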
@attr.s(slots=True, auto_attribs=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -125,6 +146,7 @@ class PersistEventsStore:
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+ @trace
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -154,6 +176,10 @@ class PersistEventsStore:
Returns:
Resolves when the events have been persisted
+
+ Raises:
+ PartialStateConflictError: if attempting to persist a partial state event in
+ a room that has been un-partial stated.
"""
# We want to calculate the stream orderings as late as possible, as
@@ -354,6 +380,9 @@ class PersistEventsStore:
For each room, a list of the event ids which are the forward
extremities.
+ Raises:
+ PartialStateConflictError: if attempting to persist a partial state event in
+ a room that has been un-partial stated.
"""
state_delta_for_room = state_delta_for_room or {}
new_forward_extremities = new_forward_extremities or {}
@@ -980,16 +1009,16 @@ class PersistEventsStore:
self,
room_id: str,
state_delta: DeltaState,
- stream_id: int,
) -> None:
"""Update the current state stored in the datatabase for the given room"""
- await self.db_pool.runInteraction(
- "update_current_state",
- self._update_current_state_txn,
- state_delta_by_room={room_id: state_delta},
- stream_id=stream_id,
- )
+ async with self._stream_id_gen.get_next() as stream_ordering:
+ await self.db_pool.runInteraction(
+ "update_current_state",
+ self._update_current_state_txn,
+ state_delta_by_room={room_id: state_delta},
+ stream_id=stream_ordering,
+ )
def _update_current_state_txn(
self,
@@ -1266,7 +1295,7 @@ class PersistEventsStore:
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
- txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+ self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Then update the `stream_ordering` position to mark the latest
# event as the front of the room. This should not be done for
# backfilled events because backfilled events have negative
@@ -1304,6 +1333,10 @@ class PersistEventsStore:
Returns:
new list, without events which are already in the events table.
+
+ Raises:
+ PartialStateConflictError: if attempting to persist a partial state event in
+ a room that has been un-partial stated.
"""
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)"
@@ -1315,9 +1348,24 @@ class PersistEventsStore:
event_id: outlier for event_id, outlier in txn
}
+ logger.debug(
+ "_update_outliers_txn: events=%s have_persisted=%s",
+ [ev.event_id for ev, _ in events_and_contexts],
+ have_persisted,
+ )
+
to_remove = set()
for event, context in events_and_contexts:
- if event.event_id not in have_persisted:
+ outlier_persisted = have_persisted.get(event.event_id)
+ logger.debug(
+ "_update_outliers_txn: event=%s outlier=%s outlier_persisted=%s",
+ event.event_id,
+ event.internal_metadata.is_outlier(),
+ outlier_persisted,
+ )
+
+ # Ignore events which we haven't persisted at all
+ if outlier_persisted is None:
continue
to_remove.add(event)
@@ -1327,7 +1375,6 @@ class PersistEventsStore:
# was an outlier or not - what we have is at least as good.
continue
- outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
# an outlier in the database. We now have some state at that event
@@ -1338,7 +1385,10 @@ class PersistEventsStore:
# events down /sync. In general they will be historical events, so that
# doesn't matter too much, but that is not always the case.
- logger.info("Updating state for ex-outlier event %s", event.event_id)
+ logger.info(
+ "_update_outliers_txn: Updating state for ex-outlier event %s",
+ event.event_id,
+ )
# insert into event_to_state_groups.
try:
@@ -1442,7 +1492,7 @@ class PersistEventsStore:
event.sender,
"url" in event.content and isinstance(event.content["url"], str),
event.get_state_key(),
- context.rejected or None,
+ context.rejected,
)
for event, context in events_and_contexts
),
@@ -1638,13 +1688,13 @@ class PersistEventsStore:
if not row["rejects"] and not row["redacts"]:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
- def prefill() -> None:
+ async def prefill() -> None:
for cache_entry in to_prefill:
- self.store._get_event_cache.set(
+ await self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
)
- txn.call_after(prefill)
+ txn.async_call_after(prefill)
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Invalidate the caches for the redacted event.
@@ -1653,7 +1703,7 @@ class PersistEventsStore:
_invalidate_caches_for_event.
"""
assert event.redacts is not None
- txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+ self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
@@ -1766,6 +1816,18 @@ class PersistEventsStore:
self.store.get_invited_rooms_for_local_user.invalidate,
(event.state_key,),
)
+ txn.call_after(
+ self.store.get_local_users_in_room.invalidate,
+ (event.room_id,),
+ )
+ txn.call_after(
+ self.store.get_number_joined_users_in_room.invalidate,
+ (event.room_id,),
+ )
+ txn.call_after(
+ self.store.get_user_in_room_with_profile.invalidate,
+ (event.room_id, event.state_key),
+ )
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
@@ -2215,6 +2277,11 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: Collection[Tuple[EventBase, EventContext]],
) -> None:
+ """
+ Raises:
+ PartialStateConflictError: if attempting to persist a partial state event in
+ a room that has been un-partial stated.
+ """
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
@@ -2239,19 +2306,37 @@ class PersistEventsStore:
# if we have partial state for these events, record the fact. (This happens
# here rather than in _store_event_txn because it also needs to happen when
# we de-outlier an event.)
- self.db_pool.simple_insert_many_txn(
- txn,
- table="partial_state_events",
- keys=("room_id", "event_id"),
- values=[
- (
- event.room_id,
- event.event_id,
- )
- for event, ctx in events_and_contexts
- if ctx.partial_state
- ],
- )
+ try:
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="partial_state_events",
+ keys=("room_id", "event_id"),
+ values=[
+ (
+ event.room_id,
+ event.event_id,
+ )
+ for event, ctx in events_and_contexts
+ if ctx.partial_state
+ ],
+ )
+ except self.db_pool.engine.module.IntegrityError:
+ logger.info(
+ "Cannot persist events %s in rooms %s: room has been un-partial stated",
+ [
+ event.event_id
+ for event, ctx in events_and_contexts
+ if ctx.partial_state
+ ],
+ list(
+ {
+ event.room_id
+ for event, ctx in events_and_contexts
+ if ctx.partial_state
+ }
+ ),
+ )
+ raise PartialStateConflictError()
self.db_pool.simple_upsert_many_txn(
txn,
@@ -2296,11 +2381,9 @@ class PersistEventsStore:
self.db_pool.simple_insert_many_txn(
txn,
table="event_edges",
- keys=("event_id", "prev_event_id", "room_id", "is_state"),
+ keys=("event_id", "prev_event_id"),
values=[
- (ev.event_id, e_id, ev.room_id, False)
- for ev in events
- for e_id in ev.prev_event_ids()
+ (ev.event_id, e_id) for ev in events for e_id in ev.prev_event_ids()
],
)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index d5f0059665..6e8aeed7b4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -64,6 +64,11 @@ class _BackgroundUpdates:
INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts"
REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
+ EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
+ EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
+
+ EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _CalculateChainCover:
@@ -177,11 +182,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
- # The event_thread_relation background update was replaced with the
- # event_arbitrary_relations one, which handles any relation to avoid
- # needed to potentially crawl the entire events table in the future.
- self.db_pool.updates.register_noop_background_update("event_thread_relation")
-
self.db_pool.updates.register_background_update_handler(
"event_arbitrary_relations",
self._event_arbitrary_relations,
@@ -240,6 +240,26 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
################################################################################
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS,
+ self._background_drop_invalid_event_edges_rows,
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.EVENT_EDGES_REPLACE_INDEX,
+ index_name="event_edges_event_id_prev_event_id_idx",
+ table="event_edges",
+ columns=["event_id", "prev_event_id"],
+ unique=True,
+ # the old index which just covered event_id is now redundant.
+ replaces_index="ev_edges_id",
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+ self._background_events_populate_state_key_rejections,
+ )
+
async def _background_reindex_fields_sender(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1290,3 +1310,179 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
return 0
+
+ async def _background_drop_invalid_event_edges_rows(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Drop invalid rows from event_edges
+
+ This only runs for postgres. For SQLite, it all happens synchronously.
+
+ Firstly, drop any rows with is_state=True. These may have been added a long time
+ ago, but they are no longer used.
+
+ We also drop rows that do not correspond to entries in `events`, and then
+ validate the foreign key constraint linking `event_edges` to `events`.
+ """
+
+ last_event_id = progress.get("last_event_id", "")
+
+ def drop_invalid_event_edges_txn(txn: LoggingTransaction) -> bool:
+ """Returns True if we're done."""
+
+ # first we need to find an endpoint.
+ txn.execute(
+ """
+ SELECT event_id FROM event_edges
+ WHERE event_id > ?
+ ORDER BY event_id
+ LIMIT 1 OFFSET ?
+ """,
+ (last_event_id, batch_size),
+ )
+
+ endpoint = None
+ row = txn.fetchone()
+
+ if row:
+ endpoint = row[0]
+
+ where_clause = "ee.event_id > ?"
+ args = [last_event_id]
+ if endpoint:
+ where_clause += " AND ee.event_id <= ?"
+ args.append(endpoint)
+
+ # now delete any that:
+ # - have is_state=TRUE, or
+ # - do not correspond to a row in `events`
+ txn.execute(
+ f"""
+ DELETE FROM event_edges
+ WHERE event_id IN (
+ SELECT ee.event_id
+ FROM event_edges ee
+ LEFT JOIN events ev USING (event_id)
+ WHERE ({where_clause}) AND
+ (is_state OR ev.event_id IS NULL)
+ )""",
+ args,
+ )
+
+ logger.info(
+ "cleaned up event_edges up to %s: removed %i/%i rows",
+ endpoint,
+ txn.rowcount,
+ batch_size,
+ )
+
+ if endpoint is not None:
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS,
+ {"last_event_id": endpoint},
+ )
+ return False
+
+ # if that was the final batch, we validate the foreign key.
+ #
+ # The constraint should have been in place and enforced for new rows since
+ # before we started deleting invalid rows, so there's no chance for any
+ # before we started deleting invalid rows, so there's no chance for any
+ # invalid rows to have snuck in in the meantime. In other words, this really
+ logger.info("cleaned up event_edges; enabling foreign key")
+ txn.execute(
+ "ALTER TABLE event_edges VALIDATE CONSTRAINT event_edges_event_id_fkey"
+ )
+ return True
+
+ done = await self.db_pool.runInteraction(
+ desc="drop_invalid_event_edges", func=drop_invalid_event_edges_txn
+ )
+
+ if done:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS
+ )
+
+ return batch_size
+
+ async def _background_events_populate_state_key_rejections(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Back-populate `events.state_key` and `events.rejection_reason`."""
+
+ min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"]
+ max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"]
+
+ def _populate_txn(txn: LoggingTransaction) -> bool:
+ """Returns True if we're done."""
+
+ # first we need to find an endpoint.
+ # we need to find the final row in the batch of batch_size, which means
+ # we need to skip over (batch_size-1) rows and get the next row.
+ txn.execute(
+ """
+ SELECT stream_ordering FROM events
+ WHERE stream_ordering > ? AND stream_ordering <= ?
+ ORDER BY stream_ordering
+ LIMIT 1 OFFSET ?
+ """,
+ (
+ min_stream_ordering_exclusive,
+ max_stream_ordering_inclusive,
+ batch_size - 1,
+ ),
+ )
+
+ endpoint = None
+ row = txn.fetchone()
+ if row:
+ endpoint = row[0]
+
+ where_clause = "stream_ordering > ?"
+ args = [min_stream_ordering_exclusive]
+ if endpoint:
+ where_clause += " AND stream_ordering <= ?"
+ args.append(endpoint)
+
+ # now do the updates.
+ txn.execute(
+ f"""
+ UPDATE events
+ SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id),
+ rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id)
+ WHERE ({where_clause})
+ """,
+ args,
+ )
+
+ logger.info(
+ "populated new `events` columns up to %s/%i: updated %i rows",
+ endpoint,
+ max_stream_ordering_inclusive,
+ txn.rowcount,
+ )
+
+ if endpoint is None:
+ # we're done
+ return True
+
+ progress["min_stream_ordering_exclusive"] = endpoint
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+ progress,
+ )
+ return False
+
+ done = await self.db_pool.runInteraction(
+ desc="events_populate_state_key_rejections", func=_populate_txn
+ )
+
+ if done:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS
+ )
+
+ return batch_size
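Both new background updates share the same batching shape: pick an endpoint with `LIMIT 1 OFFSET batch_size`, apply the work to the range up to that endpoint, record progress, and declare the update done once no endpoint is found. Stripped of the Synapse plumbing, one iteration looks roughly like this; `conn` is a sqlite3-style connection and `apply_batch`/`save_progress` are illustrative callables (the table and column names are assumed to be trusted, since they are interpolated into the SQL):

def run_ranged_update_batch(conn, table, order_col, start, batch_size, apply_batch, save_progress):
    """Run one batch of a ranged background update; returns True when finished.

    `apply_batch(conn, lower_exclusive, upper_inclusive)` does the real work;
    `upper_inclusive` is None on the final batch, meaning "to the end".
    """
    row = conn.execute(
        f"SELECT {order_col} FROM {table} WHERE {order_col} > ? "
        f"ORDER BY {order_col} LIMIT 1 OFFSET ?",
        (start, batch_size),
    ).fetchone()
    endpoint = row[0] if row else None

    apply_batch(conn, start, endpoint)

    if endpoint is None:
        return True  # no rows beyond `start`: the update is complete
    save_progress(endpoint)  # the next batch resumes from here
    return False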
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b99b107784..9b997c304d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
+from synapse.logging.opentracing import start_active_span, tag_args, trace
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -79,7 +80,7 @@ from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -238,7 +239,9 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+ self._get_event_cache: AsyncLruCache[
+ Tuple[str], EventCacheEntry
+ ] = AsyncLruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -292,25 +295,6 @@ class EventsWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
- async def get_received_ts(self, event_id: str) -> Optional[int]:
- """Get received_ts (when it was persisted) for the event.
-
- Raises an exception for unknown events.
-
- Args:
- event_id: The event ID to query.
-
- Returns:
- Timestamp in milliseconds, or None for events that were persisted
- before received_ts was implemented.
- """
- return await self.db_pool.simple_select_one_onecol(
- table="events",
- keyvalues={"event_id": event_id},
- retcol="received_ts",
- desc="get_received_ts",
- )
-
async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased
from the database due to a redaction.
@@ -447,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
+ @trace
+ @tag_args
async def get_events_as_list(
self,
event_ids: Collection[str],
@@ -617,7 +603,11 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
map from event id to result
"""
- event_entry_map = self._get_events_from_cache(
+ # Shortcut: check if we have any events in the *in memory* cache - this function
+ # may be called repeatedly for the same event so at this point we cannot reach
+ # out to any external cache for performance reasons. The external cache is
+ # checked later on in the `get_missing_events_from_cache_or_db` function below.
+ event_entry_map = self._get_events_from_local_cache(
event_ids,
)
@@ -649,7 +639,9 @@ class EventsWorkerStore(SQLBaseStore):
if missing_events_ids:
- async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+ async def get_missing_events_from_cache_or_db() -> Dict[
+ str, EventCacheEntry
+ ]:
"""Fetches the events in `missing_event_ids` from the database.
Also creates entries in `self._current_event_fetches` to allow
@@ -674,10 +666,18 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event
# out of the database to check it.
#
+ missing_events = {}
try:
- missing_events = await self._get_events_from_db(
+ # Try to fetch from any external cache. We already checked the
+ # in-memory cache above.
+ missing_events = await self._get_events_from_external_cache(
missing_events_ids,
)
+ # Now actually fetch any remaining events from the DB
+ db_missing_events = await self._get_events_from_db(
+ missing_events_ids - missing_events.keys(),
+ )
+ missing_events.update(db_missing_events)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
@@ -696,7 +696,7 @@ class EventsWorkerStore(SQLBaseStore):
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
# reuse the same fetch.
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
- get_missing_events_from_db()
+ get_missing_events_from_cache_or_db()
)
event_entry_map.update(missing_events)
@@ -729,15 +729,96 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id: str) -> None:
- self._get_event_cache.invalidate((event_id,))
+ def invalidate_get_event_cache_after_txn(
+ self, txn: LoggingTransaction, event_id: str
+ ) -> None:
+ """
+ Prepares a database transaction to invalidate the get event cache for a given
+ event ID when executed successfully. This is achieved by attaching two callbacks
+ to the transaction: one to invalidate the async cache and one for the in-memory
+ sync cache (importantly called in that order).
+
+ Arguments:
+ txn: the database transaction to attach the callbacks to
+ event_id: the event ID to be invalidated from caches
+ """
+
+ txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
+ txn.call_after(self._invalidate_local_get_event_cache, event_id)
+
+ async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+ """
+ Invalidates an event in the asynchronous get event cache, which may be remote.
+
+ Arguments:
+ event_id: the event ID to invalidate
+ """
+
+ await self._get_event_cache.invalidate((event_id,))
+
+ def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+ """
+ Invalidates an event in local in-memory get event caches.
+
+ Arguments:
+ event_id: the event ID to invalidate
+ """
+
+ self._get_event_cache.invalidate_local((event_id,))
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
- def _get_events_from_cache(
+ async def _get_events_from_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, EventCacheEntry]:
+ """Fetch events from the caches, both in memory and any external.
+
+ May return rejected events.
+
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
+ """
+ event_map = self._get_events_from_local_cache(
+ events, update_metrics=update_metrics
+ )
+
+ missing_event_ids = (e for e in events if e not in event_map)
+ event_map.update(
+ await self._get_events_from_external_cache(
+ events=missing_event_ids,
+ update_metrics=update_metrics,
+ )
+ )
+
+ return event_map
+
+ async def _get_events_from_external_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, EventCacheEntry]:
+ """Fetch events from any configured external cache.
+
+ May return rejected events.
+
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = await self._get_event_cache.get_external(
+ (event_id,), None, update_metrics=update_metrics
+ )
+ if ret:
+ event_map[event_id] = ret
+
+ return event_map
+
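The switch to `AsyncLruCache` gives each event lookup two tiers: a synchronous in-process LRU consulted first, then an optional external cache behind awaitable calls. The class below is a deliberately simplified stand-in (not Synapse's `AsyncLruCache`) showing that lookup order and the external-then-local invalidation ordering used by `invalidate_get_event_cache_after_txn` above:

from collections import OrderedDict
from typing import Any, Dict, Hashable, Optional

class TwoTierCache:
    """An in-memory LRU in front of an external cache (a plain dict here)."""

    def __init__(self, max_size: int = 1024) -> None:
        self._local: "OrderedDict[Hashable, Any]" = OrderedDict()
        self._external: Dict[Hashable, Any] = {}  # stand-in for e.g. a Redis client
        self._max_size = max_size

    def get_local(self, key: Hashable) -> Optional[Any]:
        value = self._local.get(key)
        if value is not None:
            self._local.move_to_end(key)  # keep recently used entries hot
        return value

    async def get(self, key: Hashable) -> Optional[Any]:
        # Cheap in-memory tier first, then fall back to the (nominally remote) tier.
        value = self.get_local(key)
        if value is None:
            value = self._external.get(key)
        return value

    async def set(self, key: Hashable, value: Any) -> None:
        self._external[key] = value
        self._local[key] = value
        if len(self._local) > self._max_size:
            self._local.popitem(last=False)  # evict the least recently used entry

    async def invalidate(self, key: Hashable) -> None:
        # Clear the external tier first and the local one second, mirroring the
        # ordering used above.
        self._external.pop(key, None)
        self._local.pop(key, None)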
+ def _get_events_from_local_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
- """Fetch events from the caches.
+ """Fetch events from the local, in memory, caches.
May return rejected events.
@@ -749,7 +830,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events:
# First check if it's in the event cache
- ret = self._get_event_cache.get(
+ ret = self._get_event_cache.get_local(
(event_id,), None, update_metrics=update_metrics
)
if ret:
@@ -771,7 +852,7 @@ class EventsWorkerStore(SQLBaseStore):
# We add the entry back into the cache as we want to keep
# recently queried events in the cache.
- self._get_event_cache.set((event_id,), cache_entry)
+ self._get_event_cache.set_local((event_id,), cache_entry)
return event_map
@@ -965,7 +1046,13 @@ class EventsWorkerStore(SQLBaseStore):
}
row_dict = self.db_pool.new_transaction(
- conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+ conn,
+ "do_fetch",
+ [],
+ [],
+ [],
+ self._fetch_event_rows,
+ events_to_fetch,
)
# We only want to resolve deferreds from the main thread
@@ -1006,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
"""
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
- events_to_fetch = event_ids
- while events_to_fetch:
- row_map = await self._enqueue_events(events_to_fetch)
+ async def _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch: Collection[str],
+ ) -> Collection[str]:
+ """
+ Fetch all of the given event_ids and return any associated redaction event_ids
+ that we still need to fetch in the next iteration.
+ """
+ row_map = await self._enqueue_events(event_ids_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids: Set[str] = set()
- for event_id in events_to_fetch:
+ for event_id in event_ids_to_fetch:
row = row_map.get(event_id)
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_event_ids)
- if events_to_fetch:
- logger.debug("Also fetching redaction events %s", events_to_fetch)
+ event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+ return event_ids_to_fetch
+
+ # Grab the initial list of events requested
+ event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids
+ )
+ # Then go and recursively find all of the associated redactions
+ with start_active_span("recursively fetching redactions"):
+ while event_ids_to_fetch:
+ logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+ event_ids_to_fetch = (
+ await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch
+ )
+ )
# build a map from event_id to EventBase
event_map: Dict[str, EventBase] = {}
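Fetching redactions is now an explicit fixed-point loop: fetch the requested events, collect any redaction event IDs not yet seen, and repeat until nothing new turns up. The same control flow in isolation, with a hypothetical async `fetch_rows` callable that maps event IDs to row objects exposing a `.redactions` collection:

async def fetch_events_and_redactions(event_ids, fetch_rows):
    """Fetch `event_ids` plus, transitively, every redaction event they reference."""
    fetched = {}
    to_fetch = set(event_ids)
    while to_fetch:
        rows = await fetch_rows(to_fetch)  # dict of event_id -> row for rows found
        fetched.update(rows)
        # Queue redactions we haven't fetched yet; the loop ends once there are none.
        to_fetch = {
            redaction_id
            for row in rows.values()
            for redaction_id in row.redactions
        } - fetched.keys()
    return fetched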
@@ -1148,7 +1254,7 @@ class EventsWorkerStore(SQLBaseStore):
event=original_ev, redacted_event=redacted_event
)
- self._get_event_cache.set((event_id,), cache_entry)
+ await self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry
if not redacted_event:
@@ -1340,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
+ @trace
+ @tag_args
async def have_seen_events(
self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
@@ -1382,7 +1490,9 @@ class EventsWorkerStore(SQLBaseStore):
# if the event cache contains the event, obviously we've seen it.
cache_results = {
- (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+ (rid, eid)
+ for (rid, eid) in keys
+ if await self._get_event_cache.contains((eid,))
}
results = dict.fromkeys(cache_results, True)
remaining = [k for k in keys if k not in cache_results]
@@ -1465,7 +1575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1481,10 +1591,11 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(
txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+ " e.outlier"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1498,7 +1609,8 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+ txn.fetchall(),
)
return await self.db_pool.runInteraction(
@@ -1507,7 +1619,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1522,11 +1634,14 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(
txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+ " e.outlier"
" FROM events AS e"
+ # NB: the next line (inner join) is what makes this query different from
+ # get_all_new_forward_event_rows.
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1541,7 +1656,8 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, instance_name))
return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+ txn.fetchall(),
)
return await self.db_pool.runInteraction(
@@ -1995,7 +2111,14 @@ class EventsWorkerStore(SQLBaseStore):
AND room_id = ?
/* Make sure event is not rejected */
AND rejections.event_id IS NULL
- ORDER BY origin_server_ts %s
+ /**
+ * First sort by the message timestamp. If the message timestamps are the
+ * same, we want the message that logically comes "next" (before/after
+ * the given timestamp) based on the DAG and its topological order (`depth`).
+ * Finally, we can tie-break based on when it was received on the server
+ * (`stream_ordering`).
+ */
+ ORDER BY origin_server_ts %s, depth %s, stream_ordering %s
LIMIT 1;
"""
@@ -2014,7 +2137,8 @@ class EventsWorkerStore(SQLBaseStore):
order = "ASC"
txn.execute(
- sql_template % (comparison_operator, order), (timestamp, room_id)
+ sql_template % (comparison_operator, order, order, order),
+ (timestamp, room_id),
)
row = txn.fetchone()
if row:
@@ -2079,14 +2203,92 @@ class EventsWorkerStore(SQLBaseStore):
def _get_partial_state_events_batch_txn(
txn: LoggingTransaction, room_id: str
) -> List[str]:
+ # we want to work through the events from oldest to newest, so
+ # we only want events whose prev_events do *not* have partial state - hence
+ # the 'NOT EXISTS' clause in the below.
+ #
+ # This is necessary because ordering by stream ordering isn't quite enough
+ # to ensure that we work from oldest to newest event (in particular,
+ # if an event is initially persisted as an outlier and later de-outliered,
+ # it can end up with a lower stream_ordering than its prev_events).
+ #
+ # Typically this means we'll only return one event per batch, but that's
+ # hard to do much about.
+ #
+ # See also: https://github.com/matrix-org/synapse/issues/13001
txn.execute(
"""
SELECT event_id FROM partial_state_events AS pse
JOIN events USING (event_id)
- WHERE pse.room_id = ?
+ WHERE pse.room_id = ? AND
+ NOT EXISTS(
+ SELECT 1 FROM event_edges AS ee
+ JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+ WHERE ee.event_id=pse.event_id
+ )
ORDER BY events.stream_ordering
LIMIT 100
""",
(room_id,),
)
return [row[0] for row in txn]
+
+ def mark_event_rejected_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ rejection_reason: Optional[str],
+ ) -> None:
+ """Mark an event that was previously accepted as rejected, or vice versa
+
+ This can happen, for example, when resyncing state during a faster join.
+
+ Args:
+ txn:
+ event_id: ID of event to update
+ rejection_reason: reason it has been rejected, or None if it is now accepted
+ """
+ if rejection_reason is None:
+ logger.info(
+ "Marking previously-processed event %s as accepted",
+ event_id,
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ "rejections",
+ keyvalues={"event_id": event_id},
+ )
+ else:
+ logger.info(
+ "Marking previously-processed event %s as rejected(%s)",
+ event_id,
+ rejection_reason,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="rejections",
+ keyvalues={"event_id": event_id},
+ values={
+ "reason": rejection_reason,
+ "last_check": self._clock.time_msec(),
+ },
+ )
+ self.db_pool.simple_update_txn(
+ txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ updatevalues={"rejection_reason": rejection_reason},
+ )
+
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+ # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+ # call '_send_invalidation_to_replication', but we actually need the other
+ # end to call _invalidate_local_get_event_cache() rather than (just)
+ # _get_event_cache.invalidate().
+ #
+ # One solution might be to (somehow) get the workers to call
+ # _invalidate_caches_for_event() (though that will invalidate more than
+ # strictly necessary).
+ #
+ # https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
deleted file mode 100644
index c15a7136b6..0000000000
--- a/synapse/storage/databases/main/group_server.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import TYPE_CHECKING
-
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
-
-class GroupServerStore(SQLBaseStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- # Register a legacy groups background update as a no-op.
- database.updates.register_noop_background_update("local_group_updates_index")
- super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index d028be16de..9b172a64d8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -37,9 +37,6 @@ from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
-BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
- "media_repository_drop_index_wo_method"
-)
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
"media_repository_drop_index_wo_method_2"
)
@@ -111,13 +108,6 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
unique=True,
)
- # the original impl of _drop_media_index_without_method was broken (see
- # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
- # impl with a no-op and run the fixed migration as
- # media_repository_drop_index_wo_method_2.
- self.db_pool.updates.register_noop_background_update(
- BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
- )
self.db_pool.updates.register_background_update_handler(
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
self._drop_media_index_without_method,
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
"initialise_mau_threepids",
[],
[],
+ [],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ba385f9fc4..f6822707e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -19,6 +19,8 @@ from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines._base import IsolationLevel
from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
@@ -214,10 +216,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# Delete all remote non-state events
for table in (
+ "event_edges",
"events",
"event_json",
"event_auth",
- "event_edges",
"event_forward_extremities",
"event_relations",
"event_search",
@@ -302,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
- self._invalidate_get_event_cache(event_id)
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
logger.info("[purge] done")
@@ -317,11 +319,38 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
Returns:
The list of state groups to delete.
"""
- return await self.db_pool.runInteraction(
- "purge_room", self._purge_room_txn, room_id
+
+ # This first runs the purge transaction with READ_COMMITTED isolation level,
+ # meaning any new rows in the tables will not trigger a serialization error.
+ # We then run the same purge a second time without this isolation level to
+ # purge any of those rows which were added during the first.
+
+ state_groups_to_delete = await self.db_pool.runInteraction(
+ "purge_room",
+ self._purge_room_txn,
+ room_id=room_id,
+ isolation_level=IsolationLevel.READ_COMMITTED,
+ )
+
+ state_groups_to_delete.extend(
+ await self.db_pool.runInteraction(
+ "purge_room",
+ self._purge_room_txn,
+ room_id=room_id,
+ ),
)
+ return state_groups_to_delete
+
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
+ # The purge collides with event persistence, so we take this lock to ensure that
+ # new events and metadata cannot be written into the room while we are deleting
+ # it; otherwise this transaction will fail.
+ if isinstance(self.database_engine, PostgresEngine):
+ txn.execute(
+ "SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE",
+ (room_id,),
+ )
+
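The `SELECT ... FOR UPDATE` takes a row-level lock on the room, so any concurrent transaction that updates or also locks that row (such as event persistence) waits until the purge commits rather than racing it. A small illustration of that locking step on its own, assuming a psycopg2-style PostgreSQL connection and the `rooms` table from the query above:

def lock_room_for_purge(conn, room_id: str) -> None:
    """Take a row lock on the room so competing writers block until we commit.

    `conn` is assumed to be a psycopg2-style connection; the lock is released
    automatically when the surrounding transaction commits or rolls back.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT room_version FROM rooms WHERE room_id = %s FOR UPDATE",
            (room_id,),
        )
        cur.fetchone()  # we only need the lock; the row value itself is irrelevant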
# First, fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d5aefe02b6..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
# limitations under the License.
import abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+)
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -50,69 +62,39 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _is_experimental_rule_enabled(
- rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
- """Used by `_load_rules` to filter out experimental rules when they
- have not been enabled.
- """
- if (
- rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
- and not experimental_config.msc3786_enabled
- ):
- return False
- if (
- rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
- and not experimental_config.msc3772_enabled
- ):
- return False
- return True
-
-
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = db_to_json(rawrule["conditions"])
- rule["actions"] = db_to_json(rawrule["actions"])
- rule["default"] = False
- ruleslist.append(rule)
-
- # We're going to be mutating this a lot, so copy it. We also filter out
- # any experimental default push rules that aren't enabled.
- rules = [
- rule
- for rule in list_with_base_rules(ruleslist)
- if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
- ]
+) -> FilteredPushRules:
+ """Take the DB rows returned from the DB and convert them into a full
+ `FilteredPushRules` object.
+ """
- for i, rule in enumerate(rules):
- rule_id = rule["rule_id"]
+ ruleslist = [
+ PushRule(
+ rule_id=rawrule["rule_id"],
+ priority_class=rawrule["priority_class"],
+ conditions=db_to_json(rawrule["conditions"]),
+ actions=db_to_json(rawrule["actions"]),
+ )
+ for rawrule in rawrules
+ ]
- if rule_id not in enabled_map:
- continue
- if rule.get("enabled", True) == bool(enabled_map[rule_id]):
- continue
+ push_rules = compile_push_rules(ruleslist)
- # Rules are cached across users.
- rule = dict(rule)
- rule["enabled"] = bool(enabled_map[rule_id])
- rules[i] = rule
+ filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
- return rules
+ return filtered_rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
- ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
+ ReceiptsWorkerStore,
EventsWorkerStore,
SQLBaseStore,
metaclass=abc.ABCMeta,
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+ async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -183,7 +165,6 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
- @cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
@@ -216,11 +197,11 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, FilteredPushRules]:
if not user_ids:
return {}
- results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -228,25 +209,25 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("*",),
desc="bulk_get_push_rules",
+ batch_size=1000,
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
- results.setdefault(row["user_name"], []).append(row)
+ raw_rules.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
- for user_id, rules in results.items():
+ results: Dict[str, FilteredPushRules] = {}
+
+ for user_id, rules in raw_rules.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
)
return results
- @cachedList(
- cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
- )
async def bulk_get_push_rules_enabled(
self, user_ids: Collection[str]
) -> Dict[str, Dict[str, bool]]:
@@ -261,6 +242,7 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
)
for row in rows:
enabled = bool(row["enabled"])
@@ -344,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: str,
rule_id: str,
priority_class: int,
- conditions: List[Dict[str, str]],
- actions: List[Union[JsonDict, str]],
+ conditions: Sequence[Mapping[str, str]],
+ actions: Sequence[Union[Mapping[str, Any], str]],
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
@@ -807,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
- txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
@@ -816,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
+ self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
"""Copy a single push rule from one room to another for a specific user.
@@ -826,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore):
rule: A push rule.
"""
# Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
+ new_conditions = []
+
# Change room id in each condition
- for condition in rule.get("conditions", []):
+ for condition in rule.conditions:
+ new_condition = condition
if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
+ new_condition = dict(condition)
+ new_condition["pattern"] = new_room_id
+
+ new_conditions.append(new_condition)
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
+ priority_class=rule.priority_class,
+ conditions=new_conditions,
+ actions=rule.actions,
)
async def copy_push_rules_from_room_to_room_for_user(
@@ -858,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
+ for rule, enabled in user_push_rules:
+ if not enabled:
+ continue
+
+ conditions = rule.conditions
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
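
The copy logic above treats `PushRule` objects as immutable and iterates `FilteredPushRules` as `(rule, enabled)` pairs. A minimal, self-contained sketch of that pattern, using hypothetical stand-in classes rather than Synapse's real `PushRule`/`FilteredPushRules` types:

from typing import Dict, Iterator, List, Tuple

class PushRule:
    """Hypothetical stand-in for Synapse's attrs-based PushRule."""
    def __init__(self, rule_id: str, priority_class: int,
                 conditions: List[dict], actions: List[object]) -> None:
        self.rule_id = rule_id
        self.priority_class = priority_class
        self.conditions = conditions
        self.actions = actions

class FilteredPushRules:
    """Hypothetical stand-in that yields (rule, enabled) pairs when iterated."""
    def __init__(self, rules: List[PushRule], enabled_map: Dict[str, bool]) -> None:
        self._rules = rules
        self._enabled_map = enabled_map

    def __iter__(self) -> Iterator[Tuple[PushRule, bool]]:
        for rule in self._rules:
            yield rule, self._enabled_map.get(rule.rule_id, True)

def copy_rule_to_room(rule: PushRule, new_room_id: str) -> PushRule:
    # Rules may be shared via caches, so build a new conditions list
    # instead of mutating the originals.
    new_conditions = []
    for condition in rule.conditions:
        if condition.get("key") == "room_id":
            condition = dict(condition)
            condition["pattern"] = new_room_id
        new_conditions.append(condition)
    rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
    return PushRule(
        rule_id=rule_id_scope + "/" + new_room_id,
        priority_class=rule.priority_class,
        conditions=new_conditions,
        actions=rule.actions,
    )
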
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 21e954ccc1..124c70ad37 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -26,7 +26,7 @@ from typing import (
cast,
)
-from synapse.api.constants import EduTypes, ReceiptTypes
+from synapse.api.constants import EduTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -36,6 +36,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
@@ -117,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self._receipts_id_gen.get_current_token()
async def get_last_receipt_event_id_for_user(
- self, user_id: str, room_id: str, receipt_types: Iterable[str]
+ self, user_id: str, room_id: str, receipt_types: Collection[str]
) -> Optional[str]:
"""
Fetch the event ID for the latest receipt in a room with one of the given receipt types.
@@ -125,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore):
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
- receipt_type: The receipt types to fetch. Earlier receipt types
- are given priority if multiple receipts point to the same event.
+ receipt_types: The receipt types to fetch.
Returns:
The latest receipt, if one exists.
"""
- latest_event_id: Optional[str] = None
- latest_stream_ordering = 0
- for receipt_type in receipt_types:
- result = await self._get_last_receipt_event_id_for_user(
- user_id, room_id, receipt_type
- )
- if result is None:
- continue
- event_id, stream_ordering = result
-
- if latest_event_id is None or latest_stream_ordering < stream_ordering:
- latest_event_id = event_id
- latest_stream_ordering = stream_ordering
+ result = await self.db_pool.runInteraction(
+ "get_last_receipt_event_id_for_user",
+ self.get_last_receipt_for_user_txn,
+ user_id,
+ room_id,
+ receipt_types,
+ )
+ if not result:
+ return None
- return latest_event_id
+ event_id, _ = result
+ return event_id
- @cached()
- async def _get_last_receipt_event_id_for_user(
- self, user_id: str, room_id: str, receipt_type: str
+ def get_last_receipt_for_user_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ room_id: str,
+ receipt_types: Collection[str],
) -> Optional[Tuple[str, int]]:
"""
- Fetch the event ID and stream ordering for the latest receipt.
+ Fetch the event ID and stream_ordering for the latest receipt in a room
+ with one of the given receipt types.
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
- receipt_type: The receipt type to fetch.
+ receipt_types: The receipt types to fetch.
Returns:
- The event ID and stream ordering of the latest receipt, if one exists;
- otherwise `None`.
+ The event ID and stream ordering of the latest receipt, if one exists.
"""
- sql = """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "receipt_type", receipt_types
+ )
+
+ sql = f"""
SELECT event_id, stream_ordering
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
- WHERE user_id = ?
+ WHERE {clause}
+ AND user_id = ?
AND room_id = ?
- AND receipt_type = ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
"""
- def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
- txn.execute(sql, (user_id, room_id, receipt_type))
- return cast(Optional[Tuple[str, int]], txn.fetchone())
+ args.extend((user_id, room_id))
+ txn.execute(sql, args)
- return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
async def get_receipts_for_user(
self, user_id: str, receipt_types: Iterable[str]
@@ -576,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> None:
self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
- self._get_last_receipt_event_id_for_user.invalidate(
- (user_id, room_id, receipt_type)
+
+ # We use this method to invalidate so that we don't end up with circular
+ # dependencies between the receipts and push action stores.
+ self._attempt_to_invalidate_cache(
+ "get_unread_event_push_actions_by_room_for_user", (room_id,)
)
def process_replication_rows(
@@ -673,17 +682,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False,
)
- # When updating a local users read receipt, remove any push actions
- # which resulted from the receipt's event and all earlier events.
- if (
- self.hs.is_mine_id(user_id)
- and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
- and stream_ordering is not None
- ):
- self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
- txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
- )
-
return rx_ts
def _graph_to_linear(
@@ -764,6 +762,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
linearized_event_id,
data,
stream_id=stream_id,
+ # Read committed is actually beneficial here because we check for a receipt with
+ # a greater stream ordering, and checking the very latest data at SELECT time is
+ # better than using the data as of the transaction start.
+ isolation_level=IsolationLevel.READ_COMMITTED,
)
# If the receipt was older than the currently persisted one, nothing to do.
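
The rewritten receipt lookup builds a single query with an IN-list over the receipt types plus extra bound parameters, rather than issuing one cached query per type. A rough sketch of that query-building pattern, using a simplified stand-in for `make_in_list_sql_clause` (the real helper is engine-aware):

from typing import Collection, List, Tuple

def make_in_list_clause(column: str, values: Collection[str]) -> Tuple[str, List[str]]:
    # Simplified stand-in: always expand to one placeholder per value.
    placeholders = ", ".join("?" for _ in values)
    return f"{column} IN ({placeholders})", list(values)

def build_last_receipt_query(
    user_id: str, room_id: str, receipt_types: Collection[str]
) -> Tuple[str, List[str]]:
    clause, args = make_in_list_clause("receipt_type", receipt_types)
    sql = f"""
        SELECT event_id, stream_ordering
        FROM receipts_linearized
        INNER JOIN events USING (room_id, event_id)
        WHERE {clause} AND user_id = ? AND room_id = ?
        ORDER BY stream_ordering DESC
        LIMIT 1
    """
    args.extend((user_id, room_id))
    return sql, args

# Example: one query covers every requested receipt type at once.
sql, args = build_last_receipt_query("@alice:example.org", "!room:example.org", ["m.read"])
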
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4991360b70..7fb9c801da 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
"""
user_id: str
+ token_id: int
is_guest: bool = False
shadow_banned: bool = False
- token_id: Optional[int] = None
device_id: Optional[str] = None
valid_until_ms: Optional[int] = None
token_owner: str = attr.ib()
@@ -1805,21 +1805,10 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
columns=["creation_ts"],
)
- # we no longer use refresh tokens, but it's possible that some people
- # might have a background update queued to build this index. Just
- # clear the background update.
- self.db_pool.updates.register_noop_background_update(
- "refresh_tokens_device_index"
- )
-
self.db_pool.updates.register_background_update_handler(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- self.db_pool.updates.register_noop_background_update(
- "user_threepids_grandfather"
- )
-
self.db_pool.updates.register_background_index_update(
"user_external_ids_user_id_idx",
index_name="user_external_ids_user_id_idx",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
- if aggregation_key:
- where_clause.append("aggregation_key = ?")
- where_args.append(aggregation_key)
-
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 68d4fc2e64..bef66f1992 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import (
import attr
-from synapse.api.constants import EventContentFields, EventTypes, JoinRules
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ JoinRules,
+ PublicRoomsFilterFields,
+)
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -170,7 +175,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
rooms.creator, state.encryption, state.is_federatable AS federatable,
rooms.is_public AS public, state.join_rules, state.guest_access,
state.history_visibility, curr.current_state_events AS state_events,
- state.avatar, state.topic
+ state.avatar, state.topic, state.room_type
FROM rooms
LEFT JOIN room_stats_state state USING (room_id)
LEFT JOIN room_stats_current curr USING (room_id)
@@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
desc="get_public_room_ids",
)
+ def _construct_room_type_where_clause(
+ self, room_types: Union[List[Union[str, None]], None]
+ ) -> Tuple[Union[str, None], List[str]]:
+ if not room_types:
+ return None, []
+ else:
+ # We use None when we want to get rooms without a type
+ is_null_clause = ""
+ if None in room_types:
+ is_null_clause = "OR room_type IS NULL"
+ room_types = [value for value in room_types if value is not None]
+
+ list_clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_type", room_types
+ )
+
+ return f"({list_clause} {is_null_clause})", args
+
async def count_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
ignore_non_federatable: bool,
+ search_filter: Optional[dict],
) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
@@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Args:
network_tuple
ignore_non_federatable: If true filters out non-federatable rooms
+ search_filter
"""
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+ query_args += args
+
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
@@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
OR history_visibility = 'world_readable'
)
+ {room_type_clause}
AND joined_members > 0
"""
@@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
if ignore_non_federatable:
where_clauses.append("is_federatable")
- if search_filter and search_filter.get("generic_search_term", None):
- search_term = "%" + search_filter["generic_search_term"] + "%"
+ if search_filter and search_filter.get(
+ PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+ ):
+ search_term = (
+ "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%"
+ )
where_clauses.append(
"""
@@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
search_term.lower(),
]
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ if room_type_clause:
+ where_clauses.append(room_type_clause)
+ query_args += args
+
where_clause = ""
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
@@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
sql = f"""
SELECT
room_id, name, topic, canonical_alias, joined_members,
- avatar, history_visibility, guest_access, join_rules
+ avatar, history_visibility, guest_access, join_rules, room_type
FROM (
{published_sql}
) published
@@ -549,7 +596,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room, rooms.room_version, rooms.creator,
state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
- state.guest_access, state.history_visibility, curr.current_state_events
+ state.guest_access, state.history_visibility, curr.current_state_events,
+ state.room_type
FROM room_stats_state state
INNER JOIN room_stats_current curr USING (room_id)
INNER JOIN rooms USING (room_id)
@@ -593,12 +641,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"version": room[5],
"creator": room[6],
"encryption": room[7],
- "federatable": room[8],
- "public": room[9],
+ # room_stats_state.federatable is an integer on sqlite.
+ "federatable": bool(room[8]),
+ # rooms.is_public is an integer on sqlite.
+ "public": bool(room[9]),
"join_rules": room[10],
"guest_access": room[11],
"history_visibility": room[12],
"state_events": room[13],
+ "room_type": room[14],
}
)
@@ -1109,25 +1160,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return room_servers
async def clear_partial_state_room(self, room_id: str) -> bool:
- # this can race with incoming events, so we watch out for FK errors.
- # TODO(faster_joins): this still doesn't completely fix the race, since the persist process
- # is not atomic. I fear we need an application-level lock.
+ """Clears the partial state flag for a room.
+
+ Args:
+ room_id: The room whose partial state flag is to be cleared.
+
+ Returns:
+ `True` if the partial state flag has been cleared successfully.
+
+ `False` if the partial state flag could not be cleared because the room
+ still contains events with partial state.
+ """
try:
await self.db_pool.runInteraction(
"clear_partial_state_room", self._clear_partial_state_room_txn, room_id
)
return True
- except self.db_pool.engine.module.DatabaseError as e:
- # TODO(faster_joins): how do we distinguish between FK errors and other errors?
- logger.warning(
+ except self.db_pool.engine.module.IntegrityError as e:
+ # Assume that any `IntegrityError`s are due to partial state events.
+ logger.info(
"Exception while clearing lazy partial-state-room %s, retrying: %s",
room_id,
e,
)
return False
- @staticmethod
- def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None:
+ def _clear_partial_state_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> None:
DatabasePool.simple_delete_txn(
txn,
table="partial_state_rooms_servers",
@@ -1138,7 +1198,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
table="partial_state_rooms",
keyvalues={"room_id": room_id},
)
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+ @cached()
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
@@ -1164,6 +1226,7 @@ class _BackgroundUpdates:
POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
+ ADD_ROOM_TYPE_COLUMN = "add_room_type_column"
_REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1198,6 +1261,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_add_rooms_room_version_column,
)
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ self._background_add_room_type_column,
+ )
+
# BG updates to change the type of room_depth.min_depth
self.db_pool.updates.register_background_update_handler(
_BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
@@ -1567,6 +1635,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
+ async def _background_add_room_type_column(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update to go and add room_type information to `room_stats_state`
+ table from `event_json` table.
+ """
+
+ last_room_id = progress.get("room_id", "")
+
+ def _background_add_room_type_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
+ sql = """
+ SELECT state.room_id, json FROM event_json
+ INNER JOIN current_state_events AS state USING (event_id)
+ WHERE state.room_id > ? AND type = 'm.room.create'
+ ORDER BY state.room_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+ room_id_to_create_event_results = txn.fetchall()
+
+ new_last_room_id = None
+ for room_id, event_json in room_id_to_create_event_results:
+ event_dict = db_to_json(event_json)
+
+ room_type = event_dict.get("content", {}).get(
+ EventContentFields.ROOM_TYPE, None
+ )
+ if isinstance(room_type, str):
+ self.db_pool.simple_update_txn(
+ txn,
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_type": room_type},
+ )
+
+ new_last_room_id = room_id
+
+ if new_last_room_id is None:
+ return True
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ {"room_id": new_last_room_id},
+ )
+
+ return False
+
+ end = await self.db_pool.runInteraction(
+ "_background_add_room_type_column",
+ _background_add_room_type_column_txn,
+ )
+
+ if end:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN
+ )
+
+ return batch_size
+
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
def __init__(
@@ -1643,9 +1774,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
servers,
)
- @staticmethod
def _store_partial_state_room_txn(
- txn: LoggingTransaction, room_id: str, servers: Collection[str]
+ self, txn: LoggingTransaction, room_id: str, servers: Collection[str]
) -> None:
DatabasePool.simple_insert_txn(
txn,
@@ -1660,6 +1790,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
keys=("room_id", "server_name"),
values=((room_id, s) for s in servers),
)
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
@@ -1875,9 +2006,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+ # We join on room_stats_state despite not using any columns from it
+ # because the join can influence the number of rows returned;
+ # e.g. a room that doesn't have state, maybe because it was deleted.
+ # The query returning the total count should be consistent with
+ # the query returning the results.
sql = """
SELECT COUNT(*) as total_event_reports
FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
{}
""".format(
where_clause
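
The room-type filter above treats `None` in the requested list as "rooms with no type", turning it into an `IS NULL` branch alongside the IN-list. A self-contained sketch of that clause construction, again with a simplified IN-list helper rather than Synapse's real one:

from typing import List, Optional, Tuple, Union

def room_type_where_clause(
    room_types: Optional[List[Union[str, None]]]
) -> Tuple[Optional[str], List[str]]:
    if not room_types:
        return None, []

    # None in the filter means "rooms without a room type".
    is_null_clause = " OR room_type IS NULL" if None in room_types else ""
    concrete_types = [t for t in room_types if t is not None]

    placeholders = ", ".join("?" for _ in concrete_types)
    list_clause = f"room_type IN ({placeholders})" if concrete_types else "FALSE"
    return f"({list_clause}{is_null_clause})", concrete_types

# e.g. filter for spaces and untyped rooms:
clause, args = room_type_where_clause(["m.space", None])
# clause == "(room_type IN (?) OR room_type IS NULL)", args == ["m.space"]
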
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 31bc8c5601..4f0adb136a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -30,8 +31,6 @@ from typing import (
import attr
from synapse.api.constants import EventTypes, Membership
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -56,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -184,34 +184,109 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
+ """
+ Returns a list of users in the room, sorted by longest in the room first
+ (i.e. by lowest join depth). This matches the sort used by
+ `get_current_hosts_in_room()` so the cache can be re-used, and the
+ ordering is harmless here in any case.
+ """
+
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
+ """
+ Returns a list of users in the room, sorted by longest in the room first
+ (i.e. by lowest join depth). This matches the sort used by
+ `get_current_hosts_in_room()` so the cache can be re-used, and the
+ ordering is harmless here in any case.
+ """
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
if self._current_state_events_membership_up_to_date:
sql = """
- SELECT state_key FROM current_state_events
- WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+ SELECT c.state_key FROM current_state_events as c
+ /* Get the depth of the event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
+ /* Sorted by lowest depth first */
+ ORDER BY e.depth ASC;
"""
else:
sql = """
- SELECT state_key FROM room_memberships as m
+ SELECT c.state_key FROM room_memberships as m
+ /* Get the depth of the event from the events table */
+ INNER JOIN events AS e USING (event_id)
INNER JOIN current_state_events as c
ON m.event_id = c.event_id
AND m.room_id = c.room_id
AND m.user_id = c.state_key
WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+ /* Sorted by lowest depth first */
+ ORDER BY e.depth ASC;
"""
txn.execute(sql, (room_id, Membership.JOIN))
return [r[0] for r in txn]
+ @cached()
+ def get_user_in_room_with_profile(
+ self, room_id: str, user_id: str
+ ) -> Dict[str, ProfileInfo]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="get_user_in_room_with_profile", list_name="user_ids"
+ )
+ async def get_subset_users_in_room_with_profiles(
+ self, room_id: str, user_ids: Collection[str]
+ ) -> Dict[str, ProfileInfo]:
+ """Get a mapping from user ID to profile information for a list of users
+ in a given room.
+
+ The profile information comes directly from this room's `m.room.member`
+ events, and so may be specific to this room rather than part of a user's
+ global profile. To avoid privacy leaks, the profile data should only be
+ revealed to users who are already in this room.
+
+ Args:
+ room_id: The ID of the room to retrieve the users of.
+ user_ids: a list of users in the room to run the query for
+
+ Returns:
+ A mapping from user ID to ProfileInfo.
+ """
+
+ def _get_subset_users_in_room_with_profiles(
+ txn: LoggingTransaction,
+ ) -> Dict[str, ProfileInfo]:
+ clause, ids = make_in_list_sql_clause(
+ self.database_engine, "c.state_key", user_ids
+ )
+
+ sql = """
+ SELECT state_key, display_name, avatar_url FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s
+ """ % (
+ clause,
+ )
+ txn.execute(sql, (room_id, Membership.JOIN, *ids))
+
+ return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_subset_users_in_room_with_profiles",
+ _get_subset_users_in_room_with_profiles,
+ )
+
@cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles(
self, room_id: str
@@ -228,6 +303,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
A mapping from user ID to ProfileInfo.
+
+ Preconditions:
+ - There is full state available for the room (it is not partial-stated).
"""
def _get_users_in_room_with_profiles(
@@ -338,6 +416,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
+ async def get_number_joined_users_in_room(self, room_id: str) -> int:
+ return await self.db_pool.simple_select_one_onecol(
+ table="current_state_events",
+ keyvalues={"room_id": room_id, "membership": Membership.JOIN},
+ retcol="COUNT(*)",
+ desc="get_number_joined_users_in_room",
+ )
+
+ @cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
) -> List[RoomsForUser]:
@@ -416,6 +503,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
+ """Get all the rooms for this *local* user where the membership for this user
+ matches one in the membership list.
+
+ Args:
+ user_id: The user ID.
+ membership_list: A list of synapse.api.constants.Membership
+ values which the user must be in.
+
+ Returns:
+ The RoomsForUser that the user matches the membership types.
+ """
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
@@ -444,6 +542,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
+ @cached(iterable=True)
+ async def get_local_users_in_room(self, room_id: str) -> List[str]:
+ """
+ Retrieves a list of the current room members who are local to the server.
+ """
+ return await self.db_pool.simple_select_onecol(
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "membership": Membership.JOIN},
+ retcol="user_id",
+ desc="get_local_users_in_room",
+ )
+
+ async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+ """
+ Check whether a given local user is currently joined to the given room.
+
+ Returns:
+ A boolean indicating whether the user is currently joined to the room.
+
+ Raises:
+ Exception when called with a user who is not local to this homeserver.
+ """
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'check_local_user_in_room' on "
+ "non-local user %s" % (user_id,),
+ )
+
+ (
+ membership,
+ member_event_id,
+ ) = await self.get_local_current_membership_for_user_in_room(
+ user_id=user_id,
+ room_id=room_id,
+ )
+
+ return membership == Membership.JOIN
+
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
@@ -476,7 +612,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results_dict.get("membership"), results_dict.get("event_id")
- @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -647,25 +783,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return frozenset(r.room_id for r in rooms)
- @cached(
- max_entries=500000,
- cache_context=True,
- iterable=True,
- prune_unread_entries=False,
+ @cached(max_entries=10000)
+ async def does_pair_of_users_share_a_room(
+ self, user_id: str, other_user_id: str
+ ) -> bool:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
)
- async def get_users_who_share_room_with_user(
- self, user_id: str, cache_context: _CacheContext
+ async def _do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
+ ) -> Mapping[str, Optional[bool]]:
+ """Return mapping from user ID to whether they share a room with the
+ given user.
+
+ Note: `None` and `False` are equivalent and mean they don't share a
+ room.
+ """
+
+ def do_users_share_a_room_txn(
+ txn: LoggingTransaction, user_ids: Collection[str]
+ ) -> Dict[str, bool]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ # This query works by fetching both the list of rooms for the target
+ # user and the set of other users, and then checking if there is any
+ # overlap.
+ sql = f"""
+ SELECT b.state_key
+ FROM (
+ SELECT room_id FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+ ) AS a
+ INNER JOIN (
+ SELECT room_id, state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+ ) AS b using (room_id)
+ LIMIT 1
+ """
+
+ txn.execute(sql, (user_id, *args))
+ return {u: True for u, in txn}
+
+ to_return = {}
+ for batch_user_ids in batch_iter(other_user_ids, 1000):
+ res = await self.db_pool.runInteraction(
+ "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+ )
+ to_return.update(res)
+
+ return to_return
+
+ async def do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
) -> Set[str]:
+ """Return the set of users who share a room with the first users"""
+
+ user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+ return {u for u, share_room in user_dict.items() if share_room}
+
+ async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
"""Returns the set of users who share a room with `user_id`"""
- room_ids = await self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate
- )
+ room_ids = await self.get_rooms_for_user(user_id)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = await self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
+ user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
return user_who_share_room
@@ -694,161 +881,92 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset()
- async def get_joined_users_from_context(
- self, event: EventBase, context: EventContext
- ) -> Dict[str, ProfileInfo]:
- state_group: Union[object, int] = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
+ async def get_joined_user_ids_from_state(
+ self, room_id: str, state: StateMap[str]
+ ) -> Set[str]:
+ """
+ For a given state map, return the set of joined user IDs in the room.
- current_state_ids = await context.get_current_state_ids()
- assert current_state_ids is not None
- assert state_group is not None
- return await self._get_joined_users_from_context(
- event.room_id, state_group, current_state_ids, event=event, context=context
- )
+ This method checks the local event cache, before calling
+ `_get_user_ids_from_membership_event_ids` for any uncached events.
+ """
- async def get_joined_users_from_state(
- self, room_id: str, state_entry: "_StateCacheEntry"
- ) -> Dict[str, ProfileInfo]:
- state_group: Union[object, int] = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
+ with Measure(self._clock, "get_joined_user_ids_from_state"):
+ users_in_room = set()
+ member_event_ids = [
+ e_id for key, e_id in state.items() if key[0] == EventTypes.Member
+ ]
- assert state_group is not None
- with Measure(self._clock, "get_joined_users_from_state"):
- return await self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
- )
+ # We check if we have any of the member event ids in the event cache
+ # before we ask the DB
- @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
- async def _get_joined_users_from_context(
- self,
- room_id: str,
- state_group: Union[object, int],
- current_state_ids: StateMap[str],
- cache_context: _CacheContext,
- event: Optional[EventBase] = None,
- context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
- ) -> Dict[str, ProfileInfo]:
- # We don't use `state_group`, it's there so that we can cache based
- # on it. However, it's important that it's never None, since two current_states
- # with a state_group of None are likely to be different.
- assert state_group is not None
-
- users_in_room = {}
- member_event_ids = [
- e_id
- for key, e_id in current_state_ids.items()
- if key[0] == EventTypes.Member
- ]
-
- if context is not None:
- # If we have a context with a delta from a previous state group,
- # check if we also have the result from the previous group in cache.
- # If we do then we can reuse that result and simply update it with
- # any membership changes in `delta_ids`
- if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get_immediate(
- (room_id, context.prev_group), None
- )
- if prev_res and isinstance(prev_res, dict):
- users_in_room = dict(prev_res)
- member_event_ids = [
- e_id
- for key, e_id in context.delta_ids.items()
- if key[0] == EventTypes.Member
- ]
- for etype, state_key in context.delta_ids:
- if etype == EventTypes.Member:
- users_in_room.pop(state_key, None)
-
- # We check if we have any of the member event ids in the event cache
- # before we ask the DB
-
- # We don't update the event cache hit ratio as it completely throws off
- # the hit ratio counts. After all, we don't populate the cache if we
- # miss it here
- event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
-
- missing_member_event_ids = []
- for event_id in member_event_ids:
- ev_entry = event_map.get(event_id)
- if ev_entry and not ev_entry.event.rejected_reason:
- if ev_entry.event.membership == Membership.JOIN:
- users_in_room[ev_entry.event.state_key] = ProfileInfo(
- display_name=ev_entry.event.content.get("displayname", None),
- avatar_url=ev_entry.event.content.get("avatar_url", None),
- )
- else:
- missing_member_event_ids.append(event_id)
-
- if missing_member_event_ids:
- event_to_memberships = await self._get_joined_profiles_from_event_ids(
- missing_member_event_ids
+ # We don't update the event cache hit ratio as it completely throws off
+ # the hit ratio counts. After all, we don't populate the cache if we
+ # miss it here
+ event_map = self._get_events_from_local_cache(
+ member_event_ids, update_metrics=False
)
- users_in_room.update(row for row in event_to_memberships.values() if row)
-
- if event is not None and event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- if event.event_id in member_event_ids:
- users_in_room[event.state_key] = ProfileInfo(
- display_name=event.content.get("displayname", None),
- avatar_url=event.content.get("avatar_url", None),
+
+ missing_member_event_ids = []
+ for event_id in member_event_ids:
+ ev_entry = event_map.get(event_id)
+ if ev_entry and not ev_entry.event.rejected_reason:
+ if ev_entry.event.membership == Membership.JOIN:
+ users_in_room.add(ev_entry.event.state_key)
+ else:
+ missing_member_event_ids.append(event_id)
+
+ if missing_member_event_ids:
+ event_to_memberships = (
+ await self._get_user_ids_from_membership_event_ids(
+ missing_member_event_ids
)
+ )
+ users_in_room.update(
+ user_id for user_id in event_to_memberships.values() if user_id
+ )
- return users_in_room
+ return users_in_room
- @cached(max_entries=10000)
- def _get_joined_profile_from_event_id(
+ @cached(
+ max_entries=10000,
+ # This name matches the old function that has been replaced - the cache name
+ # is kept here to maintain backwards compatibility.
+ name="_get_joined_profile_from_event_id",
+ )
+ def _get_user_id_from_membership_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
+ cached_method_name="_get_user_id_from_membership_event_id",
list_name="event_ids",
)
- async def _get_joined_profiles_from_event_ids(
+ async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
- ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+ ) -> Dict[str, Optional[str]]:
"""For given set of member event_ids check if they point to a join
- event and if so return the associated user and profile info.
+ event.
Args:
event_ids: The member event IDs to lookup
Returns:
- Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+ Map from event ID to `user_id`, or None if event is not a join.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
- retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN},
- batch_size=500,
- desc="_get_joined_profiles_from_event_ids",
+ batch_size=1000,
+ desc="_get_user_ids_from_membership_event_ids",
)
- return {
- row["event_id"]: (
- row["user_id"],
- ProfileInfo(
- avatar_url=row["avatar_url"], display_name=row["display_name"]
- ),
- )
- for row in rows
- }
+ return {row["event_id"]: row["user_id"] for row in rows}
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -894,44 +1012,77 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
- """Get current hosts in room based on current state."""
+ async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+ """
+ Get current hosts in room based on current state.
+
+ Sorting by the servers that have been in the room the longest is a
+ useful heuristic because they are the most likely to have anything we
+ ask about.
+
+ Returns:
+ A list of servers sorted by longest in the room first (i.e. sorted
+ by the lowest join depth first).
+ """
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
- if users is not None:
- return {get_domain_from_id(u) for u in users}
-
- if isinstance(self.database_engine, Sqlite3Engine):
+ if users is None and isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
- return {get_domain_from_id(u) for u in users}
+
+ if users is not None:
+ # Because `users` is sorted from lowest -> highest depth, the list
+ # of domains will also be sorted that way.
+ domains: List[str] = []
+ # We use a `Set` just for fast lookups
+ domain_set: Set[str] = set()
+ for u in users:
+ domain = get_domain_from_id(u)
+ if domain not in domain_set:
+ domain_set.add(domain)
+ domains.append(domain)
+ return domains
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
- def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+ # Returns a list of servers currently joined in the room sorted by
+ # longest in the room first (aka. with the lowest depth). The
+ # heuristic of sorting by servers who have been in the room the
+ # longest is good because they're most likely to have anything we
+ # ask about.
sql = """
- SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
- FROM current_state_events
+ SELECT
+ /* Match the domain part of the MXID */
+ substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+ FROM current_state_events c
+ /* Get the depth of the event from the events table */
+ INNER JOIN events AS e USING (event_id)
WHERE
- type = 'm.room.member'
- AND membership = 'join'
- AND room_id = ?
+ /* Find any join state events in the room */
+ c.type = 'm.room.member'
+ AND c.membership = 'join'
+ AND c.room_id = ?
+ /* Group all state events from the same domain into their own buckets (groups) */
+ GROUP BY server_domain
+ /* Sorted by lowest depth first */
+ ORDER BY min(e.depth) ASC;
"""
txn.execute(sql, (room_id,))
- return {d for d, in txn}
+ return [d for d, in txn]
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
async def get_joined_hosts(
- self, room_id: str, state_entry: "_StateCacheEntry"
+ self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
@@ -944,7 +1095,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
- room_id, state_group, state_entry=state_entry
+ room_id, state_group, state, state_entry=state_entry
)
@cached(num_args=2, max_entries=10000, iterable=True)
@@ -952,6 +1103,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self,
room_id: str,
state_group: Union[object, int],
+ state: StateMap[str],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
@@ -1006,12 +1158,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
- joined_users = await self.get_joined_users_from_state(
- room_id, state_entry
+ joined_user_ids = await self.get_joined_user_ids_from_state(
+ room_id, state
)
cache.hosts_to_joined_users = {}
- for user_id in joined_users:
+ for user_id in joined_user_ids:
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
@@ -1090,6 +1242,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
+ async def is_locally_forgotten_room(self, room_id: str) -> bool:
+ """Returns whether all local users have forgotten this room_id.
+
+ Args:
+ room_id: The room ID to query.
+
+ Returns:
+ Whether the room is forgotten.
+ """
+
+ sql = """
+ SELECT count(*) > 0 FROM local_current_membership
+ INNER JOIN room_memberships USING (room_id, event_id)
+ WHERE
+ room_id = ?
+ AND forgotten = 0;
+ """
+
+ rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+ # `count(*)` always returns an integer.
+ # If any rows still exist, it means someone has not forgotten this room yet.
+ return not rows[0][0]
+
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 78e0773b2a..f6e24b68d2 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -113,7 +113,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
- EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings"
@@ -132,15 +131,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
- # we used to have a background update to turn the GIN index into a
- # GIST one; we no longer do that (obviously) because we actually want
- # a GIN index. However, it's possible that some people might still have
- # the background update queued, so we register a handler to clear the
- # background update.
- self.db_pool.updates.register_noop_background_update(
- self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
- )
-
self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5e6efbd0fc..0b10af0e58 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,15 +419,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# anything that was rejected should have the same state as its
# predecessor.
if context.rejected:
- assert context.state_group == context.state_group_before_event
+ state_group = context.state_group_before_event
+ else:
+ state_group = context.state_group
self.db_pool.simple_update_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event.event_id},
- updatevalues={"state_group": context.state_group},
+ updatevalues={"state_group": state_group},
)
+ # the event may now be rejected where it was not before, or vice versa,
+ # in which case we need to update the rejected flags.
+ if bool(context.rejected) != (event.rejected_reason is not None):
+ self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
self.db_pool.simple_delete_one_txn(
txn,
table="partial_state_events",
@@ -435,11 +442,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# TODO(faster_joins): need to do something about workers here
+ # https://github.com/matrix-org/synapse/issues/12994
txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
txn.call_after(
self._get_state_group_for_event.prefill,
(event.event_id,),
- context.state_group,
+ state_group,
)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index b95dbef678..b4c652acf3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from typing_extensions import Counter
@@ -120,11 +120,6 @@ class StatsStore(StateDeltasStore):
self.db_pool.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
- # we no longer need to perform clean-up, but we will give ourselves
- # the potential to reintroduce it in the future – so documentation
- # will still encourage the use of this no-op handler.
- self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
- self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
async def _populate_stats_process_users(
self, progress: JsonDict, batch_size: int
@@ -243,6 +238,7 @@ class StatsStore(StateDeltasStore):
* avatar
* canonical_alias
* guest_access
+ * room_type
A is_federatable key can also be included with a boolean value.
@@ -268,6 +264,7 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
"guest_access",
+ "room_type",
):
field = fields.get(col, sentinel)
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
@@ -300,6 +297,7 @@ class StatsStore(StateDeltasStore):
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
allow_none=True,
+ desc="get_earliest_token_for_stats",
)
async def bulk_update_stats_delta(
@@ -576,7 +574,7 @@ class StatsStore(StateDeltasStore):
state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
- room_state = {
+ room_state: Dict[str, Union[None, bool, str]] = {
"join_rules": None,
"history_visibility": None,
"encryption": None,
@@ -585,6 +583,7 @@ class StatsStore(StateDeltasStore):
"avatar": None,
"canonical_alias": None,
"is_federatable": True,
+ "room_type": None,
}
for event in state_event_map.values():
@@ -608,6 +607,9 @@ class StatsStore(StateDeltasStore):
room_state["is_federatable"] = (
event.content.get(EventContentFields.FEDERATE, True) is True
)
+ room_type = event.content.get(EventContentFields.ROOM_TYPE)
+ if isinstance(room_type, str):
+ room_state["room_type"] = room_type
await self.update_room_state(room_id, room_state)
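
The stats update above copies the room type out of the `m.room.create` content only when it is a string, so malformed values never reach `room_stats_state`. A tiny sketch of that guard, assuming (as in the Matrix spec) that the create event's content stores the room type under the "type" key:

from typing import Any, Dict, Optional

def extract_room_type(create_event_content: Dict[str, Any]) -> Optional[str]:
    # EventContentFields.ROOM_TYPE corresponds to the "type" content key.
    room_type = create_event_content.get("type")
    # Ignore anything that is not a plain string (e.g. None, dicts, numbers).
    return room_type if isinstance(room_type, str) else None

assert extract_room_type({"type": "m.space"}) == "m.space"
assert extract_room_type({"type": 42}) is None
assert extract_room_type({}) is None
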
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8e88784d3c..a347430aa7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -46,16 +46,19 @@ from typing import (
Set,
Tuple,
cast,
+ overload,
)
import attr
from frozendict import frozendict
+from typing_extensions import Literal
from twisted.internet import defer
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -795,6 +798,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return RoomStreamToken(topo, stream_ordering)
+ @overload
+ def get_stream_id_for_event_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none: Literal[False] = False,
+ ) -> int:
+ ...
+
+ @overload
+ def get_stream_id_for_event_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none: bool = False,
+ ) -> Optional[int]:
+ ...
+
def get_stream_id_for_event_txn(
self,
txn: LoggingTransaction,
@@ -1002,8 +1023,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
async def get_all_new_events_stream(
- self, from_id: int, current_id: int, limit: int
- ) -> Tuple[int, List[EventBase]]:
+ self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+ ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
@@ -1012,19 +1033,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
+ get_prev_content: whether to fetch previous event content
Returns:
- A tuple of (next_id, events), where `next_id` is the next value to
- pass as `from_id` (it will either be the stream_ordering of the
- last returned event, or, if fewer than `limit` events were found,
- the `current_id`).
+ A tuple of (next_id, events, event_to_received_ts), where `next_id`
+ is the next value to pass as `from_id` (it will either be the
+ stream_ordering of the last returned event, or, if fewer than `limit`
+ events were found, the `current_id`). The `event_to_received_ts` is
+ a dictionary mapping event ID to the event `received_ts`.
"""
def get_all_new_events_stream_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, List[str]]:
+ ) -> Tuple[int, Dict[str, Optional[int]]]:
sql = (
- "SELECT e.stream_ordering, e.event_id"
+ "SELECT e.stream_ordering, e.event_id, e.received_ts"
" FROM events AS e"
" WHERE"
" ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1039,15 +1062,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if len(rows) == limit:
upper_bound = rows[-1][0]
- return upper_bound, [row[1] for row in rows]
+ event_to_received_ts: Dict[str, Optional[int]] = {
+ row[1]: row[2] for row in rows
+ }
+ return upper_bound, event_to_received_ts
- upper_bound, event_ids = await self.db_pool.runInteraction(
+ upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = await self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(
+ event_to_received_ts.keys(),
+ get_prev_content=get_prev_content,
+ )
- return upper_bound, events
+ return upper_bound, events, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:
@@ -1318,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, next_token
+ @trace
async def paginate_room_events(
self,
room_id: str,
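
The overloads added to `get_stream_id_for_event_txn` let type checkers know that the default `allow_none=False` call can never return `None`. A generic, self-contained illustration of that typing pattern (not Synapse code):

from typing import Dict, Literal, Optional, overload

_STREAM_IDS: Dict[str, int] = {"$event_a": 17}

@overload
def stream_id_for_event(event_id: str, allow_none: Literal[False] = False) -> int:
    ...

@overload
def stream_id_for_event(event_id: str, allow_none: bool = False) -> Optional[int]:
    ...

def stream_id_for_event(event_id: str, allow_none: bool = False) -> Optional[int]:
    stream_id = _STREAM_IDS.get(event_id)
    if stream_id is None and not allow_none:
        raise KeyError(event_id)
    return stream_id

# mypy infers `int` here, because the default matches Literal[False] ...
ordering: int = stream_id_for_event("$event_a")
# ... and `Optional[int]` here, because a plain `bool` matches the second overload.
maybe_ordering: Optional[int] = stream_id_for_event("$missing", allow_none=True)
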
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index fa9eadaca7..a7fcc564a9 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -24,6 +24,7 @@ from synapse.storage.database import (
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
+from synapse.util.caches import intern_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -136,7 +137,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
- key = (typ, state_key)
+ key = (intern_string(typ), intern_string(state_key))
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..bb64543c1f 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -202,7 +202,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
- cache_entry = cache.get(group)
+ # If we are asked explicitly for a subset of keys, we only ask for those
+ # from the cache. This ensures that the `DictionaryCache` can make
+ # better decisions about what to cache and what to expire.
+ dict_keys = None
+ if not state_filter.has_wildcards():
+ dict_keys = state_filter.concrete_types()
+
+ cache_entry = cache.get(group, dict_keys=dict_keys)
state_dict_ids = cache_entry.value
if cache_entry.full or state_filter.is_full():
@@ -400,14 +407,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
- current_state_ids: StateMap[str],
+ current_state_ids: Optional[StateMap[str]],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
+ At least one of `current_state_ids` and `prev_group` must be provided. Whenever
+ `prev_group` is not None, `delta_ids` must also not be None.
+
Args:
event_id: The event ID for which the state was calculated
room_id
- prev_group: A previous state group for the room, optional.
+ prev_group: A previous state group for the room.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
@@ -418,10 +428,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group ID
"""
- def _store_state_group_txn(txn: LoggingTransaction) -> int:
- if current_state_ids is None:
- # AFAIK, this can never happen
- raise Exception("current_state_ids cannot be None")
+ if prev_group is None and current_state_ids is None:
+ raise Exception("current_state_ids and prev_group can't both be None")
+
+ if prev_group is not None and delta_ids is None:
+ raise Exception("delta_ids is None when prev_group is not None")
+
+ def insert_delta_group_txn(
+ txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
+ ) -> Optional[int]:
+ """Try and persist the new group as a delta.
+
+ Requires that we have the state as a delta from a previous state group.
+
+ Returns:
+ The ID of the new state group if it was successfully persisted as a
+ delta, or None if the state needs to be persisted as a full group.
+ """
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ # If the chain of state group deltas is getting too long, we fall back to
+ # persisting a complete state group.
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if potential_hops >= MAX_STATE_DELTA_HOPS:
+ return None
state_group = self._state_group_seq_gen.get_next_id_txn(txn)
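Note: the delta path above refuses to extend a chain of deltas past `MAX_STATE_DELTA_HOPS`, because resolving a group's state means walking every hop back to a full snapshot. A minimal sketch of the hop check against an in-memory edge map; the real `_count_state_group_hops_txn` reads `state_group_edges` from the database, and the cap value here is assumed for illustration:

    from typing import Dict, Optional

    MAX_STATE_DELTA_HOPS = 100  # assumed cap for illustration

    def count_hops(edges: Dict[int, int], group: int) -> int:
        """Count delta edges between `group` and its nearest full snapshot.

        `edges` is an illustrative in-memory {state_group: prev_state_group}
        mapping; the real store derives this from `state_group_edges` rows.
        """
        hops = 0
        current: Optional[int] = edges.get(group)
        while current is not None and hops < MAX_STATE_DELTA_HOPS:
            hops += 1
            current = edges.get(current)
        return hops

    # The hunk above persists a delta only while count_hops(edges, prev_group)
    # stays below the cap; otherwise it returns None and a full group is written.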
@@ -431,51 +472,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
values={"id": state_group, "room_id": room_id, "event_id": event_id},
)
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if prev_group:
- is_in_db = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- assert delta_ids is not None
-
- self.db_pool.simple_insert_txn(
- txn,
- table="state_group_edges",
- values={"state_group": state_group, "prev_state_group": prev_group},
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={"state_group": state_group, "prev_state_group": prev_group},
+ )
- self.db_pool.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- keys=("state_group", "room_id", "type", "state_key", "event_id"),
- values=[
- (state_group, room_id, key[0], key[1], state_id)
- for key, state_id in delta_ids.items()
- ],
- )
- else:
- self.db_pool.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- keys=("state_group", "room_id", "type", "state_key", "event_id"),
- values=[
- (state_group, room_id, key[0], key[1], state_id)
- for key, state_id in current_state_ids.items()
- ],
- )
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (state_group, room_id, key[0], key[1], state_id)
+ for key, state_id in delta_ids.items()
+ ],
+ )
+
+ return state_group
+
+ def insert_full_state_txn(
+ txn: LoggingTransaction, current_state_ids: StateMap[str]
+ ) -> int:
+ """Persist the full state, returning the new state group."""
+ state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (state_group, room_id, key[0], key[1], state_id)
+ for key, state_id in current_state_ids.items()
+ ],
+ )
# Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
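Note: a delta group stores only its `delta_ids` rows plus an edge to `prev_group`, while a full group stores every `(type, state_key)` row, so reading a group's state means overlaying each delta on top of the nearest full snapshot. A hedged in-memory sketch of that resolution (the table layout matches the diff; the dictionaries are illustrative):

    from typing import Dict, Tuple

    StateMap = Dict[Tuple[str, str], str]  # (event_type, state_key) -> event_id

    def resolve_state(
        group: int,
        edges: Dict[int, int],      # state_group -> prev_state_group (delta edges)
        rows: Dict[int, StateMap],  # state_group -> its state_groups_state rows
    ) -> StateMap:
        # Walk from `group` back to the full snapshot (a group with no edge)...
        chain = [group]
        while chain[-1] in edges:
            chain.append(edges[chain[-1]])
        # ...then apply rows oldest-first so newer deltas overwrite older entries.
        state: StateMap = {}
        for g in reversed(chain):
            state.update(rows.get(g, {}))
        return state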
@@ -491,7 +526,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_members_cache.update,
self._state_group_members_cache.sequence,
key=state_group,
- value=dict(current_member_state_ids),
+ value=current_member_state_ids,
)
current_non_member_state_ids = {
@@ -503,13 +538,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
- value=dict(current_non_member_state_ids),
+ value=current_non_member_state_ids,
)
return state_group
+ if prev_group is not None:
+ state_group = await self.db_pool.runInteraction(
+ "store_state_group.insert_delta_group",
+ insert_delta_group_txn,
+ prev_group,
+ delta_ids,
+ )
+ if state_group is not None:
+ return state_group
+
+ # We're going to persist the state as a complete group rather than
+ # a delta, so first we need to ensure we have loaded the state map
+ # from the database.
+ if current_state_ids is None:
+ assert prev_group is not None
+ assert delta_ids is not None
+ groups = await self._get_state_for_groups([prev_group])
+ current_state_ids = dict(groups[prev_group])
+ current_state_ids.update(delta_ids)
+
return await self.db_pool.runInteraction(
- "store_state_group", _store_state_group_txn
+ "store_state_group.insert_full_state",
+ insert_full_state_txn,
+ current_state_ids,
)
async def purge_unreferenced_state_groups(
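Note: splitting `_store_state_group_txn` into `insert_delta_group_txn` and `insert_full_state_txn` moves the delta-or-full decision out of a single transaction: the delta attempt may return None, and only then does the caller materialise the full state (loading `prev_group`'s state and applying `delta_ids` when `current_state_ids` was not supplied). A compact sketch of that control flow; the callables passed in are illustrative stand-ins for the real database interactions shown in the diff:

    async def store_state_group_sketch(
        prev_group, delta_ids, current_state_ids,
        try_delta, insert_full, load_group_state,
    ):
        if prev_group is None and current_state_ids is None:
            raise Exception("current_state_ids and prev_group can't both be None")
        if prev_group is not None and delta_ids is None:
            raise Exception("delta_ids is None when prev_group is not None")

        if prev_group is not None:
            state_group = await try_delta(prev_group, delta_ids)
            if state_group is not None:
                return state_group  # persisted as a delta

        if current_state_ids is None:
            # Materialise the full state from the previous group plus the delta.
            current_state_ids = dict(await load_group_state(prev_group))
            current_state_ids.update(delta_ids)

        return await insert_full(current_state_ids)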