diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a2f8310388..e30f9c76d4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -80,6 +80,10 @@ class SQLBaseStore(metaclass=ABCMeta):
)
self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
+ # There's no easy way of invalidating this cache for just the users
+ # that have changed, so we just clear the entire thing.
+ self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
+
for user_id in members_changed:
self._attempt_to_invalidate_cache(
"get_user_in_room_with_profile", (room_id, user_id)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index cf98b0ab48..dad3731b9b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,14 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+ SynapseTags,
+ active_span,
+ set_tag,
+ start_active_span_follows_from,
+ trace,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
@@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue.append(end_item)
# also add our active opentracing span to the item so that we get a link back
- span = opentracing.active_span()
+ span = active_span()
if span:
end_item.parent_opentracing_span_contexts.append(span.context)
@@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
res = await make_deferred_yieldable(end_item.deferred.observe())
# add another opentracing span which links to the persist trace.
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
f"{task.name}_complete", (end_item.opentracing_span_context,)
):
pass
@@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
item.task.name,
item.parent_opentracing_span_contexts,
inherit_force_tracing=True,
@@ -355,7 +361,7 @@ class EventsPersistenceStorageController:
f"Found an unexpected task type in event persistence queue: {task}"
)
- @opentracing.trace
+ @trace
async def persist_events(
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -380,9 +386,21 @@ class EventsPersistenceStorageController:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
+ event_ids: List[str] = []
partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
+ event_ids.append(event.event_id)
+
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str(event_ids),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(event_ids)),
+ )
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -418,7 +436,7 @@ class EventsPersistenceStorageController:
self.main_store.get_room_max_token(),
)
- @opentracing.trace
+ @trace
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index e08f956e6e..f9ffd0e29e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,6 +29,8 @@ from typing import (
from synapse.api.constants import EventTypes
from synapse.events import EventBase
+from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
@@ -82,13 +84,15 @@ class StateStorageController:
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
- self, _room_id: str, event_ids: Collection[str]
+ self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
+ await_full_state: if `True`, will block if we do not yet have complete
+ state at these events.
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -100,7 +104,9 @@ class StateStorageController:
if not event_ids:
return {}
- event_to_groups = await self.get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -175,6 +181,7 @@ class StateStorageController:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
+ @trace
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
@@ -221,10 +228,13 @@ class StateStorageController:
return {event: event_to_state[event] for event in event_ids}
+ @trace
+ @tag_args
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -233,6 +243,9 @@ class StateStorageController:
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at these events and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from event_id -> (type, state_key) -> event_id
@@ -241,8 +254,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self._is_mine_id)
+ ):
+ # Full state is not required if the state filter is restrictive enough.
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -283,8 +300,12 @@ class StateStorageController:
)
return state_map[event_id]
+ @trace
async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@@ -292,6 +313,9 @@ class StateStorageController:
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from (type, state_key) -> state_event_id
@@ -301,7 +325,9 @@ class StateStorageController:
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
+ [event_id],
+ state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
return state_map[event_id]
@@ -323,6 +349,8 @@ class StateStorageController:
groups, state_filter or StateFilter.all()
)
+ @trace
+ @tag_args
async def get_state_group_for_events(
self,
event_ids: Collection[str],
@@ -334,6 +362,10 @@ class StateStorageController:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie. they are outliers or unknown)
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
@@ -460,6 +492,7 @@ class StateStorageController:
prev_stream_id, max_stream_id
)
+ @trace
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
@@ -493,3 +526,15 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
+
+ async def get_users_in_room_with_profiles(
+ self, room_id: str
+ ) -> Dict[str, ProfileInfo]:
+ """
+ Get the current users in the room with their profiles.
+ If the room is currently partial-stated, this will block until the room has
+ full state.
+ """
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_users_in_room_with_profiles(room_id)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ea672ff89e..b394a6658b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -39,7 +39,7 @@ from typing import (
)
import attr
-from prometheus_client import Histogram
+from prometheus_client import Counter, Histogram
from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.enterprise import adbapi
@@ -76,7 +76,8 @@ perf_logger = logging.getLogger("synapse.storage.TIME")
sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+sql_txn_count = Counter("synapse_storage_transaction_time_count", "sec", ["desc"])
+sql_txn_duration = Counter("synapse_storage_transaction_time_sum", "sec", ["desc"])
# Unique indexes which have been added in background updates. Maps from table name
@@ -795,7 +796,8 @@ class DatabasePool:
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, duration)
- sql_txn_timer.labels(desc).observe(duration)
+ sql_txn_count.labels(desc).inc(1)
+ sql_txn_duration.labels(desc).inc(duration)
async def runInteraction(
self,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a3d31d3737..4dccbb732a 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -24,9 +24,9 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -149,31 +149,6 @@ class DataStore(
],
)
- self._cache_id_gen: Optional[MultiWriterIdGenerator]
- if isinstance(self.database_engine, PostgresEngine):
- # We set the `writers` to an empty list here as we don't care about
- # missing updates over restarts, as we'll not have anything in our
- # caches to invalidate. (This reduces the amount of writes to the DB
- # that happen).
- self._cache_id_gen = MultiWriterIdGenerator(
- db_conn,
- database,
- stream_name="caches",
- instance_name=hs.get_instance_name(),
- tables=[
- (
- "cache_invalidation_stream_by_instance",
- "instance_name",
- "stream_id",
- )
- ],
- sequence_name="cache_invalidation_stream_seq",
- writers=[],
- )
-
- else:
- self._cache_id_gen = None
-
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room, (user_id,)
)
self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
- self._invalidate_cache_and_stream(
- txn, self.get_push_rules_enabled_for_user, (user_id,)
- )
# This user might be contained in the ignored_by cache for other users,
# so we have to invalidate it all.
self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2367ddeea3..12e9a42382 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -32,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
@@ -65,6 +66,31 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
psql_only=True, # The table is only on postgres DBs.
)
+ self._cache_id_gen: Optional[MultiWriterIdGenerator]
+ if isinstance(self.database_engine, PostgresEngine):
+ # We set the `writers` to an empty list here as we don't care about
+ # missing updates over restarts, as we'll not have anything in our
+ # caches to invalidate. (This reduces the amount of writes to the DB
+ # that happen).
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ stream_name="caches",
+ instance_name=hs.get_instance_name(),
+ tables=[
+ (
+ "cache_invalidation_stream_by_instance",
+ "instance_name",
+ "stream_id",
+ )
+ ],
+ sequence_name="cache_invalidation_stream_seq",
+ writers=[],
+ )
+
+ else:
+ self._cache_id_gen = None
+
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 422e0e65ca..73c95ffb6f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(user_id, device_id), None
)
- set_tag("last_deleted_stream_id", last_deleted_stream_id)
+ set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 7a6ed332aa..ca0fe8c4be 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)
- set_tag("in_cache", results)
- set_tag("not_in_cache", user_ids_not_in_cache)
+ set_tag("in_cache", str(results))
+ set_tag("not_in_cache", str(user_ids_not_in_cache))
return user_ids_not_in_cache, results
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 60f622ad71..46c0d06157 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
- set_tag("query_list", query_list)
+ set_tag("query_list", str(query_list))
if not query_list:
return {}
@@ -418,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
- set_tag("new_keys", new_keys)
+ set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1161,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
- set_tag("device_keys", device_keys)
+ set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(event_ids)
+ @trace
+ @tag_args
async def get_auth_chain_ids(
self,
room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
+ @trace
+ @tag_args
async def get_oldest_event_ids_with_depth_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
+ @trace
async def get_insertion_event_backward_extremities_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_results.reverse()
return event_results
+ @trace
+ @tag_args
async def get_successor_events(self, event_id: str) -> List[str]:
"""Fetch all events that have the given event as a prev event
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ @trace
async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
await self.db_pool.simple_upsert(
table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index dd2627037c..eabf9c9739 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,14 +12,85 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""Responsible for storing and fetching push actions / notifications.
+
+There are two main uses for push actions:
+ 1. Sending out push to a user's device; and
+ 2. Tracking per-room per-user notification counts (used in sync requests).
+
+For the former we simply use the `event_push_actions` table, which contains all
+the calculated actions for a given user (which were calculated by the
+`BulkPushRuleEvaluator`).
+
+For the latter we could simply count the number of rows in `event_push_actions`
+table for a given room/user, but in practice this is *very* heavyweight when
+there were a large number of notifications (due to e.g. the user never reading a
+room). Plus, keeping all push actions indefinitely uses a lot of disk space.
+
+To fix these issues, we add a new table `event_push_summary` that tracks
+per-user per-room counts of all notifications that happened before a stream
+ordering S. Thus, to get the notification count for a user / room we can simply
+query a single row in `event_push_summary` and count the number of rows in
+`event_push_actions` with a stream ordering larger than S (and as long as S is
+"recent", the number of rows needing to be scanned will be small).
+
+The `event_push_summary` table is updated via a background job that periodically
+chooses a new stream ordering S' (usually the latest stream ordering), counts
+all notifications in `event_push_actions` between the existing S and S', and
+adds them to the existing counts in `event_push_summary`.
+
+This allows us to delete old rows from `event_push_actions` once those rows have
+been counted and added to `event_push_summary` (we call this process
+"rotation").
+
+
+We need to handle when a user sends a read receipt to the room. Again this is
+done as a background process. For each receipt we clear the row in
+`event_push_summary` and count the number of notifications in
+`event_push_actions` that happened after the receipt but before S, and insert
+that count into `event_push_summary` (If the receipt happened *after* S then we
+simply clear the `event_push_summary`.)
+
+Note that its possible that if the read receipt is for an old event the relevant
+`event_push_actions` rows will have been rotated and we get the wrong count
+(it'll be too low). We accept this as a rare edge case that is unlikely to
+impact the user much (since the vast majority of read receipts will be for the
+latest event).
+
+The last complication is to handle the race where we request the notifications
+counts after a user sends a read receipt into the room, but *before* the
+background update handles the receipt (without any special handling the counts
+would be outdated). We fix this by including in `event_push_summary` the read
+receipt we used when updating `event_push_summary`, and every time we query the
+table we check if that matches the most recent read receipt in the room. If yes,
+continue as above, if not we simply query the `event_push_actions` table
+directly.
+
+Since read receipts are almost always for recent events, scanning the
+`event_push_actions` table in this case is unlikely to be a problem. Even if it
+is a problem, it is temporary until the background job handles the new read
+receipt.
+"""
+
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -93,7 +164,9 @@ class NotifCounts:
highlight_count: int = 0
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+ actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -166,7 +239,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
- for a given user in a given room after the given read receipt.
+ for a given user in a given room after their latest read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
@@ -177,9 +250,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
- A dict containing the counts mentioned earlier in this docstring,
- respectively under the keys "notify_count", "highlight_count" and
- "unread_count".
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -194,20 +266,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
room_id: str,
user_id: str,
) -> NotifCounts:
+ # Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_receipt_for_user_txn(
txn,
user_id,
room_id,
- receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ receipt_types=(
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
)
- stream_ordering = None
if result:
_, stream_ordering = result
- if stream_ordering is None:
- # Either last_read_event_id is None, or it's an event we don't have (e.g.
- # because it's been purged), in which case retrieve the stream ordering for
+ else:
+ # If the user has no receipts in the room, retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -224,10 +299,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
def _get_unread_counts_by_pos_txn(
- self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ receipt_stream_ordering: int,
) -> NotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
+
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ receipt_stream_ordering: The stream ordering of the user's latest
+ receipt in the room. If there are no receipts, the stream ordering
+ of the user's join event.
+
+ Returns
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
counts = NotifCounts()
@@ -255,7 +346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
OR last_receipt_stream_ordering = ?
)
""",
- (room_id, user_id, stream_ordering, stream_ordering),
+ (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
)
row = txn.fetchone()
@@ -265,7 +356,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
counts.notify_count += row[1]
counts.unread_count += row[2]
- # Next we need to count highlights, which aren't summarized
+ # Next we need to count highlights, which aren't summarised
sql = """
SELECT COUNT(*) FROM event_push_actions
WHERE user_id = ?
@@ -273,17 +364,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND stream_ordering > ?
AND highlight = 1
"""
- txn.execute(sql, (user_id, room_id, stream_ordering))
+ txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
row = txn.fetchone()
if row:
counts.highlight_count += row[0]
# Finally we need to count push actions that aren't included in the
- # summary returned above, e.g. recent events that haven't been
- # summarized yet, or the summary is empty due to a recent read receipt.
- stream_ordering = max(stream_ordering, summary_stream_ordering)
+ # summary returned above. This might be due to recent events that haven't
+ # been summarised yet or the summary is out of date due to a recent read
+ # receipt.
+ start_unread_stream_ordering = max(
+ receipt_stream_ordering, summary_stream_ordering
+ )
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, stream_ordering
+ txn, room_id, user_id, start_unread_stream_ordering
)
counts.notify_count += notify_count
@@ -304,6 +398,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
Does not consult `event_push_summary` table, which may include push
actions that have been deleted from `event_push_actions` table.
+
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ stream_ordering: The (exclusive) minimum stream ordering to consider.
+ max_stream_ordering: The (inclusive) maximum stream ordering to consider.
+ If this is not given, then no maximum is applied.
+
+ Return:
+ A tuple of the notif count and unread count in the given range.
"""
# If there have been no events in the room since the stream ordering,
@@ -376,6 +481,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
+
# find rooms that have a read receipt in them and return the next
# push actions
def get_after_receipt(
@@ -383,28 +489,41 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next
# push actions
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight "
- " FROM ("
- " SELECT room_id,"
- " MAX(stream_ordering) as stream_ordering"
- " FROM events"
- " INNER JOIN receipts_linearized USING (room_id, event_id)"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- ") AS rl,"
- " event_push_actions AS ep"
- " WHERE"
- " ep.room_id = rl.room_id"
- " AND ep.stream_ordering > rl.stream_ordering"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering ASC LIMIT ?"
+
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight
+ FROM (
+ SELECT room_id,
+ MAX(stream_ordering) as stream_ordering
+ FROM events
+ INNER JOIN receipts_linearized USING (room_id, event_id)
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ ) AS rl,
+ event_push_actions AS ep
+ WHERE
+ ep.room_id = rl.room_id
+ AND ep.stream_ordering > rl.stream_ordering
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering ASC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
@@ -418,24 +537,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight "
- " FROM event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id NOT IN ("
- " SELECT room_id FROM receipts_linearized"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- " )"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering ASC LIMIT ?"
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight
+ FROM event_push_actions AS ep
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ ep.room_id NOT IN (
+ SELECT room_id FROM receipts_linearized
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ )
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering ASC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
@@ -485,34 +616,47 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
+
# find rooms that have a read receipt in them and return the most recent
# push actions
def get_after_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight, e.received_ts"
- " FROM ("
- " SELECT room_id,"
- " MAX(stream_ordering) as stream_ordering"
- " FROM events"
- " INNER JOIN receipts_linearized USING (room_id, event_id)"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- ") AS rl,"
- " event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id = rl.room_id"
- " AND ep.stream_ordering > rl.stream_ordering"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight, e.received_ts
+ FROM (
+ SELECT room_id,
+ MAX(stream_ordering) as stream_ordering
+ FROM events
+ INNER JOIN receipts_linearized USING (room_id, event_id)
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ ) AS rl,
+ event_push_actions AS ep
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ ep.room_id = rl.room_id
+ AND ep.stream_ordering > rl.stream_ordering
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering DESC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
@@ -526,24 +670,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_no_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
- sql = (
- "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
- " ep.highlight, e.received_ts"
- " FROM event_push_actions AS ep"
- " INNER JOIN events AS e USING (room_id, event_id)"
- " WHERE"
- " ep.room_id NOT IN ("
- " SELECT room_id FROM receipts_linearized"
- " WHERE receipt_type = 'm.read' AND user_id = ?"
- " GROUP BY room_id"
- " )"
- " AND ep.user_id = ?"
- " AND ep.stream_ordering > ?"
- " AND ep.stream_ordering <= ?"
- " AND ep.notif = 1"
- " ORDER BY ep.stream_ordering DESC LIMIT ?"
+ receipt_types_clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.UNSTABLE_READ_PRIVATE,
+ ),
+ )
+
+ sql = f"""
+ SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+ ep.highlight, e.received_ts
+ FROM event_push_actions AS ep
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ ep.room_id NOT IN (
+ SELECT room_id FROM receipts_linearized
+ WHERE {receipt_types_clause} AND user_id = ?
+ GROUP BY room_id
+ )
+ AND ep.user_id = ?
+ AND ep.stream_ordering > ?
+ AND ep.stream_ordering <= ?
+ AND ep.notif = 1
+ ORDER BY ep.stream_ordering DESC LIMIT ?
+ """
+ args.extend(
+ (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
)
- args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
@@ -606,7 +762,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
async def add_push_actions_to_staging(
self,
event_id: str,
- user_id_actions: Dict[str, List[Union[dict, str]]],
+ user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -623,7 +779,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
- user_id: str, actions: List[Union[dict, str]]
+ user_id: str, actions: Collection[Union[Mapping, str]]
) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
@@ -769,12 +925,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# [10, <none>, 20], we should treat this as being equivalent to
# [10, 10, 20].
#
- sql = (
- "SELECT received_ts FROM events"
- " WHERE stream_ordering <= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT 1"
- )
+ sql = """
+ SELECT received_ts FROM events
+ WHERE stream_ordering <= ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
while range_end - range_start > 0:
middle = (range_end + range_start) // 2
@@ -802,14 +958,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self, stream_ordering: int
) -> Optional[int]:
def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
- sql = (
- "SELECT e.received_ts"
- " FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ? AND notif = 1"
- " ORDER BY ep.stream_ordering ASC"
- " LIMIT 1"
- )
+ sql = """
+ SELECT e.received_ts
+ FROM event_push_actions AS ep
+ JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
+ WHERE ep.stream_ordering > ? AND notif = 1
+ ORDER BY ep.stream_ordering ASC
+ LIMIT 1
+ """
txn.execute(sql, (stream_ordering,))
return cast(Optional[Tuple[int]], txn.fetchone())
@@ -858,10 +1014,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
Any push actions which predate the user's most recent read receipt are
now redundant, so we can remove them from `event_push_actions` and
update `event_push_summary`.
+
+ Returns true if all new receipts have been processed.
"""
limit = 100
+ # The (inclusive) receipt stream ID that was previously processed..
min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_last_receipt_stream_id",
@@ -871,6 +1030,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
max_receipts_stream_id = self._receipts_id_gen.get_current_token()
+ # The (inclusive) event stream ordering that was previously summarised.
+ old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
sql = """
SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
FROM receipts_linearized AS r
@@ -895,13 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
rows = txn.fetchall()
- old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_push_summary_stream_ordering",
- keyvalues={},
- retcol="stream_ordering",
- )
-
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
for _, room_id, user_id, stream_ordering in rows:
@@ -920,10 +1080,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(room_id, user_id, stream_ordering),
)
+ # Fetch the notification counts between the stream ordering of the
+ # latest receipt and what was previously summarised.
notif_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
)
+ # Replace the previous summary with the new counts.
self.db_pool.simple_upsert_txn(
txn,
table="event_push_summary",
@@ -956,10 +1119,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return len(rows) < limit
def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
- """Archives older notifications into event_push_summary. Returns whether
- the archiving process has caught up or not.
+ """Archives older notifications (from event_push_actions) into event_push_summary.
+
+ Returns whether the archiving process has caught up or not.
"""
+ # The (inclusive) event stream ordering that was previously summarised.
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
@@ -974,7 +1139,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
- """,
+ """,
(old_rotate_stream_ordering, self._rotate_count),
)
stream_row = txn.fetchone()
@@ -993,19 +1158,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
- self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+ self._rotate_notifs_before_txn(
+ txn, old_rotate_stream_ordering, rotate_to_stream_ordering
+ )
return caught_up
def _rotate_notifs_before_txn(
- self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ old_rotate_stream_ordering: int,
+ rotate_to_stream_ordering: int,
) -> None:
- old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_push_summary_stream_ordering",
- keyvalues={},
- retcol="stream_ordering",
- )
+ """Archives older notifications (from event_push_actions) into event_push_summary.
+
+ Any event_push_actions between old_rotate_stream_ordering (exclusive) and
+ rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
+ table.
+
+ Args:
+ txn: The database transaction.
+ old_rotate_stream_ordering: The previous maximum event stream ordering.
+ rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
+ """
# Calculate the new counts that should be upserted into event_push_summary
sql = """
@@ -1090,12 +1265,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(rotate_to_stream_ordering,),
)
- async def _remove_old_push_actions_that_have_rotated(
- self,
- ) -> None:
- """Clear out old push actions that have been summarized."""
+ async def _remove_old_push_actions_that_have_rotated(self) -> None:
+ """Clear out old push actions that have been summarised."""
- # We want to clear out anything that older than a day that *has* already
+ # We want to clear out anything that is older than a day that *has* already
# been rotated.
rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
table="event_push_summary_stream_ordering",
@@ -1119,7 +1292,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering <= ? AND highlight = 0
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
- """,
+ """,
(
max_stream_ordering_to_delete,
batch_size,
@@ -1215,16 +1388,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# NB. This assumes event_ids are globally unique since
# it makes the query easier to index
- sql = (
- "SELECT epa.event_id, epa.room_id,"
- " epa.stream_ordering, epa.topological_ordering,"
- " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
- " FROM event_push_actions epa, events e"
- " WHERE epa.event_id = e.event_id"
- " AND epa.user_id = ? %s"
- " AND epa.notif = 1"
- " ORDER BY epa.stream_ordering DESC"
- " LIMIT ?" % (before_clause,)
+ sql = """
+ SELECT epa.event_id, epa.room_id,
+ epa.stream_ordering, epa.topological_ordering,
+ epa.actions, epa.highlight, epa.profile_tag, e.received_ts
+ FROM event_push_actions epa, events e
+ WHERE epa.event_id = e.event_id
+ AND epa.user_id = ? %s
+ AND epa.notif = 1
+ ORDER BY epa.stream_ordering DESC
+ LIMIT ?
+ """ % (
+ before_clause,
)
txn.execute(sql, args)
return cast(
@@ -1247,7 +1422,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
]
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
for action in actions:
if not isinstance(action, dict):
continue
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f600f1190..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+ @trace
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1490,7 +1492,7 @@ class PersistEventsStore:
event.sender,
"url" in event.content and isinstance(event.content["url"], str),
event.get_state_key(),
- context.rejected or None,
+ context.rejected,
)
for event, context in events_and_contexts
),
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5914a35420..8a7cdb024d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
+from synapse.logging.opentracing import start_active_span, tag_args, trace
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
+ @trace
+ @tag_args
async def get_events_as_list(
self,
event_ids: Collection[str],
@@ -600,7 +603,11 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
map from event id to result
"""
- event_entry_map = await self._get_events_from_cache(
+ # Shortcut: check if we have any events in the *in memory* cache - this function
+ # may be called repeatedly for the same event so at this point we cannot reach
+ # out to any external cache for performance reasons. The external cache is
+ # checked later on in the `get_missing_events_from_cache_or_db` function below.
+ event_entry_map = self._get_events_from_local_cache(
event_ids,
)
@@ -632,7 +639,9 @@ class EventsWorkerStore(SQLBaseStore):
if missing_events_ids:
- async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+ async def get_missing_events_from_cache_or_db() -> Dict[
+ str, EventCacheEntry
+ ]:
"""Fetches the events in `missing_event_ids` from the database.
Also creates entries in `self._current_event_fetches` to allow
@@ -657,10 +666,18 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event
# out of the database to check it.
#
+ missing_events = {}
try:
- missing_events = await self._get_events_from_db(
+ # Try to fetch from any external cache. We already checked the
+ # in-memory cache above.
+ missing_events = await self._get_events_from_external_cache(
missing_events_ids,
)
+ # Now actually fetch any remaining events from the DB
+ db_missing_events = await self._get_events_from_db(
+ missing_events_ids - missing_events.keys(),
+ )
+ missing_events.update(db_missing_events)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
@@ -679,7 +696,7 @@ class EventsWorkerStore(SQLBaseStore):
# cancellations, since multiple `_get_events_from_cache_or_db` calls can
# reuse the same fetch.
missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
- get_missing_events_from_db()
+ get_missing_events_from_cache_or_db()
)
event_entry_map.update(missing_events)
@@ -754,7 +771,54 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
- """Fetch events from the caches.
+ """Fetch events from the caches, both in memory and any external.
+
+ May return rejected events.
+
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
+ """
+ event_map = self._get_events_from_local_cache(
+ events, update_metrics=update_metrics
+ )
+
+ missing_event_ids = (e for e in events if e not in event_map)
+ event_map.update(
+ await self._get_events_from_external_cache(
+ events=missing_event_ids,
+ update_metrics=update_metrics,
+ )
+ )
+
+ return event_map
+
+ async def _get_events_from_external_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, EventCacheEntry]:
+ """Fetch events from any configured external cache.
+
+ May return rejected events.
+
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = await self._get_event_cache.get_external(
+ (event_id,), None, update_metrics=update_metrics
+ )
+ if ret:
+ event_map[event_id] = ret
+
+ return event_map
+
+ def _get_events_from_local_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, EventCacheEntry]:
+ """Fetch events from the local, in memory, caches.
May return rejected events.
@@ -766,7 +830,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events:
# First check if it's in the event cache
- ret = await self._get_event_cache.get(
+ ret = self._get_event_cache.get_local(
(event_id,), None, update_metrics=update_metrics
)
if ret:
@@ -788,7 +852,7 @@ class EventsWorkerStore(SQLBaseStore):
# We add the entry back into the cache as we want to keep
# recently queried events in the cache.
- await self._get_event_cache.set((event_id,), cache_entry)
+ self._get_event_cache.set_local((event_id,), cache_entry)
return event_map
@@ -1029,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
"""
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
- events_to_fetch = event_ids
- while events_to_fetch:
- row_map = await self._enqueue_events(events_to_fetch)
+ async def _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch: Collection[str],
+ ) -> Collection[str]:
+ """
+ Fetch all of the given event_ids and return any associated redaction event_ids
+ that we still need to fetch in the next iteration.
+ """
+ row_map = await self._enqueue_events(event_ids_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids: Set[str] = set()
- for event_id in events_to_fetch:
+ for event_id in event_ids_to_fetch:
row = row_map.get(event_id)
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_event_ids)
- if events_to_fetch:
- logger.debug("Also fetching redaction events %s", events_to_fetch)
+ event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+ return event_ids_to_fetch
+
+ # Grab the initial list of events requested
+ event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids
+ )
+ # Then go and recursively find all of the associated redactions
+ with start_active_span("recursively fetching redactions"):
+ while event_ids_to_fetch:
+ logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+ event_ids_to_fetch = (
+ await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch
+ )
+ )
# build a map from event_id to EventBase
event_map: Dict[str, EventBase] = {}
@@ -1363,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
+ @trace
+ @tag_args
async def have_seen_events(
self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
@@ -2110,14 +2195,92 @@ class EventsWorkerStore(SQLBaseStore):
def _get_partial_state_events_batch_txn(
txn: LoggingTransaction, room_id: str
) -> List[str]:
+ # we want to work through the events from oldest to newest, so
+ # we only want events whose prev_events do *not* have partial state - hence
+ # the 'NOT EXISTS' clause in the below.
+ #
+ # This is necessary because ordering by stream ordering isn't quite enough
+ # to ensure that we work from oldest to newest event (in particular,
+ # if an event is initially persisted as an outlier and later de-outliered,
+ # it can end up with a lower stream_ordering than its prev_events).
+ #
+ # Typically this means we'll only return one event per batch, but that's
+ # hard to do much about.
+ #
+ # See also: https://github.com/matrix-org/synapse/issues/13001
txn.execute(
"""
SELECT event_id FROM partial_state_events AS pse
JOIN events USING (event_id)
- WHERE pse.room_id = ?
+ WHERE pse.room_id = ? AND
+ NOT EXISTS(
+ SELECT 1 FROM event_edges AS ee
+ JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+ WHERE ee.event_id=pse.event_id
+ )
ORDER BY events.stream_ordering
LIMIT 100
""",
(room_id,),
)
return [row[0] for row in txn]
+
+ def mark_event_rejected_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ rejection_reason: Optional[str],
+ ) -> None:
+ """Mark an event that was previously accepted as rejected, or vice versa
+
+ This can happen, for example, when resyncing state during a faster join.
+
+ Args:
+ txn:
+ event_id: ID of event to update
+ rejection_reason: reason it has been rejected, or None if it is now accepted
+ """
+ if rejection_reason is None:
+ logger.info(
+ "Marking previously-processed event %s as accepted",
+ event_id,
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ "rejections",
+ keyvalues={"event_id": event_id},
+ )
+ else:
+ logger.info(
+ "Marking previously-processed event %s as rejected(%s)",
+ event_id,
+ rejection_reason,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="rejections",
+ keyvalues={"event_id": event_id},
+ values={
+ "reason": rejection_reason,
+ "last_check": self._clock.time_msec(),
+ },
+ )
+ self.db_pool.simple_update_txn(
+ txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ updatevalues={"rejection_reason": rejection_reason},
+ )
+
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+ # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+ # call '_send_invalidation_to_replication', but we actually need the other
+ # end to call _invalidate_local_get_event_cache() rather than (just)
+ # _get_event_cache.invalidate().
+ #
+ # One solution might be to (somehow) get the workers to call
+ # _invalidate_caches_for_event() (though that will invalidate more than
+ # strictly necessary).
+ #
+ # https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 768f95d16c..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
# limitations under the License.
import abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+)
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -50,60 +62,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _is_experimental_rule_enabled(
- rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
- """Used by `_load_rules` to filter out experimental rules when they
- have not been enabled.
- """
- if (
- rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
- and not experimental_config.msc3786_enabled
- ):
- return False
- if (
- rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
- and not experimental_config.msc3772_enabled
- ):
- return False
- return True
-
-
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = db_to_json(rawrule["conditions"])
- rule["actions"] = db_to_json(rawrule["actions"])
- rule["default"] = False
- ruleslist.append(rule)
-
- # We're going to be mutating this a lot, so copy it. We also filter out
- # any experimental default push rules that aren't enabled.
- rules = [
- rule
- for rule in list_with_base_rules(ruleslist)
- if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
- ]
+) -> FilteredPushRules:
+ """Take the DB rows returned from the DB and convert them into a full
+ `FilteredPushRules` object.
+ """
- for i, rule in enumerate(rules):
- rule_id = rule["rule_id"]
+ ruleslist = [
+ PushRule(
+ rule_id=rawrule["rule_id"],
+ priority_class=rawrule["priority_class"],
+ conditions=db_to_json(rawrule["conditions"]),
+ actions=db_to_json(rawrule["actions"]),
+ )
+ for rawrule in rawrules
+ ]
- if rule_id not in enabled_map:
- continue
- if rule.get("enabled", True) == bool(enabled_map[rule_id]):
- continue
+ push_rules = compile_push_rules(ruleslist)
- # Rules are cached across users.
- rule = dict(rule)
- rule["enabled"] = bool(enabled_map[rule_id])
- rules[i] = rule
+ filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
- return rules
+ return filtered_rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+ async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -183,7 +165,6 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
- @cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
@@ -216,11 +197,11 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, FilteredPushRules]:
if not user_ids:
return {}
- results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -234,20 +215,19 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
- results.setdefault(row["user_name"], []).append(row)
+ raw_rules.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
- for user_id, rules in results.items():
+ results: Dict[str, FilteredPushRules] = {}
+
+ for user_id, rules in raw_rules.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
)
return results
- @cachedList(
- cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
- )
async def bulk_get_push_rules_enabled(
self, user_ids: Collection[str]
) -> Dict[str, Dict[str, bool]]:
@@ -262,6 +242,7 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
+ batch_size=1000,
)
for row in rows:
enabled = bool(row["enabled"])
@@ -345,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: str,
rule_id: str,
priority_class: int,
- conditions: List[Dict[str, str]],
- actions: List[Union[JsonDict, str]],
+ conditions: Sequence[Mapping[str, str]],
+ actions: Sequence[Union[Mapping[str, Any], str]],
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
@@ -808,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
- txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
@@ -817,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
+ self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
"""Copy a single push rule from one room to another for a specific user.
@@ -827,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore):
rule: A push rule.
"""
# Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
+ new_conditions = []
+
# Change room id in each condition
- for condition in rule.get("conditions", []):
+ for condition in rule.conditions:
+ new_condition = condition
if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
+ new_condition = dict(condition)
+ new_condition["pattern"] = new_room_id
+
+ new_conditions.append(new_condition)
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
+ priority_class=rule.priority_class,
+ conditions=new_conditions,
+ actions=rule.actions,
)
async def copy_push_rules_from_room_to_room_for_user(
@@ -859,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
+ for rule, enabled in user_push_rules:
+ if not enabled:
+ continue
+
+ conditions = rule.conditions
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0090c9f225..124c70ad37 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: The receipt types to fetch.
Returns:
- The latest receipt, if one exists.
+ The event ID and stream ordering of the latest receipt, if one exists.
"""
clause, args = make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cb63cd9b7d..7fb9c801da 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
"""
user_id: str
+ token_id: int
is_guest: bool = False
shadow_banned: bool = False
- token_id: Optional[int] = None
device_id: Optional[str] = None
valid_until_ms: Optional[int] = None
token_owner: str = attr.ib()
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
- if aggregation_key:
- where_clause.append("aggregation_key = ?")
- where_args.append(aggregation_key)
-
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d6d485507b..b7d4baa6bb 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,7 +207,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
) -> Tuple[Union[str, None], List[str]]:
- if not room_types or not self.config.experimental.msc3827_enabled:
+ if not room_types:
return None, []
else:
# We use None when we want get rooms without a type
@@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+ # We join on room_stats_state despite not using any columns from it
+ # because the join can influence the number of rows returned;
+ # e.g. a room that doesn't have state, maybe because it was deleted.
+ # The query returning the total count should be consistent with
+ # the query returning the results.
sql = """
SELECT COUNT(*) as total_event_reports
FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
{}
""".format(
where_clause
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index df6b82660e..046ad3a11c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -55,6 +56,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -183,7 +185,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
@@ -281,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
A mapping from user ID to ProfileInfo.
+
+ Preconditions:
+ - There is full state available for the room (it is not partial-stated).
"""
def _get_users_in_room_with_profiles(
@@ -561,7 +566,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results_dict.get("membership"), results_dict.get("event_id")
- @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -732,25 +737,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return frozenset(r.room_id for r in rooms)
- @cached(
- max_entries=500000,
- cache_context=True,
- iterable=True,
- prune_unread_entries=False,
+ @cached(max_entries=10000)
+ async def does_pair_of_users_share_a_room(
+ self, user_id: str, other_user_id: str
+ ) -> bool:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
)
- async def get_users_who_share_room_with_user(
- self, user_id: str, cache_context: _CacheContext
+ async def _do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
+ ) -> Mapping[str, Optional[bool]]:
+ """Return mapping from user ID to whether they share a room with the
+ given user.
+
+ Note: `None` and `False` are equivalent and mean they don't share a
+ room.
+ """
+
+ def do_users_share_a_room_txn(
+ txn: LoggingTransaction, user_ids: Collection[str]
+ ) -> Dict[str, bool]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ # This query works by fetching both the list of rooms for the target
+ # user and the set of other users, and then checking if there is any
+ # overlap.
+ sql = f"""
+ SELECT b.state_key
+ FROM (
+ SELECT room_id FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+ ) AS a
+ INNER JOIN (
+ SELECT room_id, state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+ ) AS b using (room_id)
+ LIMIT 1
+ """
+
+ txn.execute(sql, (user_id, *args))
+ return {u: True for u, in txn}
+
+ to_return = {}
+ for batch_user_ids in batch_iter(other_user_ids, 1000):
+ res = await self.db_pool.runInteraction(
+ "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+ )
+ to_return.update(res)
+
+ return to_return
+
+ async def do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
) -> Set[str]:
+ """Return the set of users who share a room with the first users"""
+
+ user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+ return {u for u, share_room in user_dict.items() if share_room}
+
+ async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
"""Returns the set of users who share a room with `user_id`"""
- room_ids = await self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate
- )
+ room_ids = await self.get_rooms_for_user(user_id)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = await self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
+ user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
return user_who_share_room
@@ -779,9 +835,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset()
- async def get_joined_users_from_state(
+ async def get_joined_user_ids_from_state(
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -792,25 +848,25 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
- return await self._get_joined_users_from_context(
+ return await self._get_joined_user_ids_from_context(
room_id, state_group, state, context=state_entry
)
@cached(num_args=2, iterable=True, max_entries=100000)
- async def _get_joined_users_from_context(
+ async def _get_joined_user_ids_from_context(
self,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
event: Optional[EventBase] = None,
context: Optional["_StateCacheEntry"] = None,
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
assert state_group is not None
- users_in_room = {}
+ users_in_room = set()
member_event_ids = [
e_id
for key, e_id in current_state_ids.items()
@@ -823,11 +879,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids`
if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get_immediate(
+ prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
(room_id, context.prev_group), None
)
- if prev_res and isinstance(prev_res, dict):
- users_in_room = dict(prev_res)
+ if prev_res and isinstance(prev_res, set):
+ users_in_room = prev_res
member_event_ids = [
e_id
for key, e_id in context.delta_ids.items()
@@ -835,7 +891,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
]
for etype, state_key in context.delta_ids:
if etype == EventTypes.Member:
- users_in_room.pop(state_key, None)
+ users_in_room.discard(state_key)
# We check if we have any of the member event ids in the event cache
# before we ask the DB
@@ -843,7 +899,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
- event_map = await self._get_events_from_cache(
+ event_map = self._get_events_from_local_cache(
member_event_ids, update_metrics=False
)
@@ -852,71 +908,64 @@ class RoomMemberWorkerStore(EventsWorkerStore):
ev_entry = event_map.get(event_id)
if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN:
- users_in_room[ev_entry.event.state_key] = ProfileInfo(
- display_name=ev_entry.event.content.get("displayname", None),
- avatar_url=ev_entry.event.content.get("avatar_url", None),
- )
+ users_in_room.add(ev_entry.event.state_key)
else:
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
- event_to_memberships = await self._get_joined_profiles_from_event_ids(
+ event_to_memberships = await self._get_user_ids_from_membership_event_ids(
missing_member_event_ids
)
- users_in_room.update(row for row in event_to_memberships.values() if row)
+ users_in_room.update(
+ user_id for user_id in event_to_memberships.values() if user_id
+ )
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
- users_in_room[event.state_key] = ProfileInfo(
- display_name=event.content.get("displayname", None),
- avatar_url=event.content.get("avatar_url", None),
- )
+ users_in_room.add(event.state_key)
return users_in_room
- @cached(max_entries=10000)
- def _get_joined_profile_from_event_id(
+ @cached(
+ max_entries=10000,
+ # This name matches the old function that has been replaced - the cache name
+ # is kept here to maintain backwards compatibility.
+ name="_get_joined_profile_from_event_id",
+ )
+ def _get_user_id_from_membership_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
+ cached_method_name="_get_user_id_from_membership_event_id",
list_name="event_ids",
)
- async def _get_joined_profiles_from_event_ids(
+ async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
- ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+ ) -> Dict[str, Optional[str]]:
"""For given set of member event_ids check if they point to a join
- event and if so return the associated user and profile info.
+ event.
Args:
event_ids: The member event IDs to lookup
Returns:
- Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+ Map from event ID to `user_id`, or None if event is not a join.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
- retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
- desc="_get_joined_profiles_from_event_ids",
+ desc="_get_user_ids_from_membership_event_ids",
)
- return {
- row["event_id"]: (
- row["user_id"],
- ProfileInfo(
- avatar_url=row["avatar_url"], display_name=row["display_name"]
- ),
- )
- for row in rows
- }
+ return {row["event_id"]: row["user_id"] for row in rows}
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1075,12 +1124,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
- joined_users = await self.get_joined_users_from_state(
+ joined_user_ids = await self.get_joined_user_ids_from_state(
room_id, state, state_entry
)
cache.hosts_to_joined_users = {}
- for user_id in joined_users:
+ for user_id in joined_user_ids:
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
@@ -1159,6 +1208,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
+ async def is_locally_forgotten_room(self, room_id: str) -> bool:
+ """Returns whether all local users have forgotten this room_id.
+
+ Args:
+ room_id: The room ID to query.
+
+ Returns:
+ Whether the room is forgotten.
+ """
+
+ sql = """
+ SELECT count(*) > 0 FROM local_current_membership
+ INNER JOIN room_memberships USING (room_id, event_id)
+ WHERE
+ room_id = ?
+ AND forgotten = 0;
+ """
+
+ rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+ # `count(*)` returns always an integer
+ # If any rows still exist it means someone has not forgotten this room yet
+ return not rows[0][0]
+
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 9674c4a757..0b10af0e58 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,15 +419,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# anything that was rejected should have the same state as its
# predecessor.
if context.rejected:
- assert context.state_group == context.state_group_before_event
+ state_group = context.state_group_before_event
+ else:
+ state_group = context.state_group
self.db_pool.simple_update_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event.event_id},
- updatevalues={"state_group": context.state_group},
+ updatevalues={"state_group": state_group},
)
+ # the event may now be rejected where it was not before, or vice versa,
+ # in which case we need to update the rejected flags.
+ if bool(context.rejected) != (event.rejected_reason is not None):
+ self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
self.db_pool.simple_delete_one_txn(
txn,
table="partial_state_events",
@@ -440,7 +447,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.call_after(
self._get_state_group_for_event.prefill,
(event.event_id,),
- context.state_group,
+ state_group,
)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2590b52f73..a347430aa7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -58,6 +58,7 @@ from twisted.internet import defer
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -1346,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, next_token
+ @trace
async def paginate_room_events(
self,
room_id: str,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index afbc85ad0c..bb64543c1f 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -202,7 +202,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
- cache_entry = cache.get(group)
+ # If we are asked explicitly for a subset of keys, we only ask for those
+ # from the cache. This ensures that the `DictionaryCache` can make
+ # better decisions about what to cache and what to expire.
+ dict_keys = None
+ if not state_filter.has_wildcards():
+ dict_keys = state_filter.concrete_types()
+
+ cache_entry = cache.get(group, dict_keys=dict_keys)
state_dict_ids = cache_entry.value
if cache_entry.full or state_filter.is_full():
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index af3bab2c15..0004d955b4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -539,15 +539,6 @@ class StateFilter:
is_mine_id: a callable which confirms if a given state_key matches a mxid
of a local user
"""
-
- # TODO(faster_joins): it's not entirely clear that this is safe. In particular,
- # there may be circumstances in which we return a piece of state that, once we
- # resync the state, we discover is invalid. For example: if it turns out that
- # the sender of a piece of state wasn't actually in the room, then clearly that
- # state shouldn't have been returned.
- # We should at least add some tests around this to see what happens.
- # https://github.com/matrix-org/synapse/issues/13006
-
# if we haven't requested membership events, then it depends on the value of
# 'include_others'
if EventTypes.Member not in self.types:
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 466e5137f2..b4bf49dace 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialStateEventsTracker.await_full_state")
async def await_full_state(self, event_ids: Collection[str]) -> None:
"""Wait for all the given events to have full state.
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialCurrentStateTracker.await_full_state")
async def await_full_state(self, room_id: str) -> None:
# We add the deferred immediately so that the DB call to check for
# partial state doesn't race when we unpartial the room.
|