diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 7cdc9fe98f..318fd7dc71 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging
import threading
import weakref
from enum import Enum, auto
+from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Collection,
- Container,
Dict,
Iterable,
List,
@@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -77,10 +76,12 @@ from synapse.storage.util.id_generators import (
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import AsyncLruCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -212,26 +213,35 @@ class EventsWorkerStore(SQLBaseStore):
-        # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
-        # updated over replication. (Multiple writers are not supported for
-        # SQLite).
+        # `StreamIdGenerator` with `is_writer=True`; otherwise we pass
+        # `is_writer=False`, so the generator merely tracks the stream
+        # position as it advances over replication. (Multiple writers are
+        # not supported for SQLite).
- if hs.get_instance_name() in hs.config.worker.writers.events:
- self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- )
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
- )
- else:
- self._stream_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering"
- )
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
+ )
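+        # Backfilled events are assigned *negative* stream orderings, so the
+        # backfill stream counts downwards (hence `step=-1` below).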
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
+ )
+
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+            max_value=events_max,  # As the stream id is shared with the `events` stream
+ limit=1000,
+ )
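+        # A `StreamChangeCache` can cheaply answer "has the state of this
+        # room changed since stream position X?" without hitting the
+        # database, so we prefill it with the recent deltas fetched above.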
+ self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache(
+ "_curr_state_delta_stream_cache",
+ min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
if hs.config.worker.run_background_tasks:
# We periodically clean out old transaction ID mappings
@@ -374,7 +384,7 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- The event, or None if the event was not found.
+ The event, or None if the event was not found and allow_none is `True`.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
@@ -474,7 +484,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = await self._get_events_from_cache_or_db(
+ event_entry_map = await self.get_unredacted_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -509,7 +519,9 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = await self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self.get_unredacted_events_from_cache_or_db(
+ [redacted_event_id]
+ )
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -588,11 +600,16 @@ class EventsWorkerStore(SQLBaseStore):
return events
@cancellable
- async def _get_events_from_cache_or_db(
- self, event_ids: Iterable[str], allow_rejected: bool = False
+ async def get_unredacted_events_from_cache_or_db(
+ self,
+ event_ids: Iterable[str],
+ allow_rejected: bool = False,
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
+ Note that the events pulled by this function will not have any redactions
+ applied, and no guarantee is made about the ordering of the events returned.
+
If events are pulled from the database, they will be cached for future lookups.
Unknown events are omitted from the response.
@@ -863,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
- state_types_to_include: Container[str],
+ state_keys_to_include: StateFilter,
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
@@ -876,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
Args:
context: The event context to retrieve state of the room from.
- state_types_to_include: The type of state events to include.
+ state_keys_to_include: The state events to include, for each event type.
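+                For example, `StateFilter.from_types([(EventTypes.Name, "")])`
+                selects only the `m.room.name` event.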
membership_user_id: An optional user ID to include the stripped membership state
events of. This is useful when generating the stripped state of a room for
invites. We want to send membership events of the inviter, so that the
@@ -885,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
A list of dictionaries, each representing a stripped state event from the room.
"""
- current_state_ids = await context.get_current_state_ids()
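+        # If requested, also select the membership event of
+        # `membership_user_id` (e.g. the inviter of an invite) on top of the
+        # requested state.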
+ if membership_user_id:
+ types = chain(
+ state_keys_to_include.to_types(),
+ [(EventTypes.Member, membership_user_id)],
+ )
+ filter = StateFilter.from_types(types)
+ else:
+ filter = state_keys_to_include
+ selected_state_ids = await context.get_current_state_ids(filter)
# We know this event is not an outlier, so this must be
# non-None.
- assert current_state_ids is not None
-
- # The state to include
- state_to_include_ids = [
- e_id
- for k, e_id in current_state_ids.items()
- if k[0] in state_types_to_include
- or (membership_user_id and k == (EventTypes.Member, membership_user_id))
- ]
+ assert selected_state_ids is not None
+
+        # Confusingly, get_current_state_ids may return events that are discarded by
+        # the filter, if they're in context._state_delta_due_to_event. Strip these away.
+ selected_state_ids = filter.filter_state(selected_state_ids)
- state_to_include = await self.get_events(state_to_include_ids)
+ state_to_include = await self.get_events(selected_state_ids.values())
return [
{
@@ -1495,21 +1516,15 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
a dict {event_id -> bool}
"""
- # if the event cache contains the event, obviously we've seen it.
-
- cache_results = {
- event_id
- for event_id in event_ids
- if await self._get_event_cache.contains((event_id,))
- }
- results = dict.fromkeys(cache_results, True)
- remaining = [
- event_id for event_id in event_ids if event_id not in cache_results
- ]
- if not remaining:
- return results
+        # TODO: We used to query the _get_event_cache here as a fast path
+        # before hitting the database, since an event already in the cache
+        # has presumably been seen before.
+ #
+ # But this is currently an invalid assumption due to the _get_event_cache
+ # not being invalidated when purging events from a room. The optimisation can
+ # be re-added after https://github.com/matrix-org/synapse/issues/13476
- def have_seen_events_txn(txn: LoggingTransaction) -> None:
+ def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1517,16 +1532,17 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
- txn.database_engine, "e.event_id", remaining
+ txn.database_engine, "e.event_id", event_ids
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}
# ... and then we can update the results for each key
- results.update({eid: (eid in found_events) for eid in remaining})
+ return {eid: (eid in found_events) for eid in event_ids}
- await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
- return results
+ return await self.db_pool.runInteraction(
+ "have_seen_events", have_seen_events_txn
+ )
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
@@ -1571,7 +1587,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id: The room ID to query.
Returns:
- dict[str:float] of complexity version to complexity.
+ Map of complexity version to complexity.
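+            e.g. `{"v1": 10.0}` ("v1" is currently the only version).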
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1969,12 +1985,17 @@ class EventsWorkerStore(SQLBaseStore):
Args:
room_id: room where the event lives
- event_id: event to check
+ event: event to check (can't be an `outlier`)
Returns:
Boolean indicating whether it's an extremity
"""
+ assert not event.internal_metadata.is_outlier(), (
+ "is_event_next_to_backward_gap(...) can't be used with `outlier` events. "
+ "This function relies on `event_backward_extremities` which won't be filled in for `outliers`."
+ )
+
def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
# If the event in question has any of its prev_events listed as a
# backward extremity, it's next to a gap.
@@ -2024,12 +2045,17 @@ class EventsWorkerStore(SQLBaseStore):
Args:
room_id: room where the event lives
- event_id: event to check
+ event: event to check (can't be an `outlier`)
Returns:
Boolean indicating whether it's an extremity
"""
+ assert not event.internal_metadata.is_outlier(), (
+ "is_event_next_to_forward_gap(...) can't be used with `outlier` events. "
+ "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`."
+ )
+
def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
# If the event in question is a forward extremity, we will just
# consider any potential forward gap as not a gap since it's one of
@@ -2110,13 +2136,33 @@ class EventsWorkerStore(SQLBaseStore):
The closest event_id otherwise None if we can't find any event in
the given direction.
"""
+ if direction == "b":
+ # Find closest event *before* a given timestamp. We use descending
+ # (which gives values largest to smallest) because we want the
+ # largest possible timestamp *before* the given timestamp.
+ comparison_operator = "<="
+ order = "DESC"
+ else:
+ # Find closest event *after* a given timestamp. We use ascending
+ # (which gives values smallest to largest) because we want the
+ # closest possible timestamp *after* the given timestamp.
+ comparison_operator = ">="
+ order = "ASC"
- sql_template = """
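+        # Both `comparison_operator` and `order` are fixed literals chosen
+        # above (never user-supplied), so interpolating them into the query
+        # with an f-string is safe.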
+ sql_template = f"""
SELECT event_id FROM events
LEFT JOIN rejections USING (event_id)
WHERE
- origin_server_ts %s ?
- AND room_id = ?
+ room_id = ?
+ AND origin_server_ts {comparison_operator} ?
+ /**
+ * Make sure the event isn't an `outlier` because we have no way
+ * to later check whether it's next to a gap. `outliers` do not
+             * have entries in the `event_edges`, `event_forward_extremities`,
+ * and `event_backward_extremities` tables to check against
+ * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`).
+ */
+ AND NOT outlier
/* Make sure event is not rejected */
AND rejections.event_id IS NULL
/**
@@ -2126,27 +2172,14 @@ class EventsWorkerStore(SQLBaseStore):
* Finally, we can tie-break based on when it was received on the server
* (`stream_ordering`).
*/
- ORDER BY origin_server_ts %s, depth %s, stream_ordering %s
+ ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order}
LIMIT 1;
"""
def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
- if direction == "b":
- # Find closest event *before* a given timestamp. We use descending
- # (which gives values largest to smallest) because we want the
- # largest possible timestamp *before* the given timestamp.
- comparison_operator = "<="
- order = "DESC"
- else:
- # Find closest event *after* a given timestamp. We use ascending
- # (which gives values smallest to largest) because we want the
- # closest possible timestamp *after* the given timestamp.
- comparison_operator = ">="
- order = "ASC"
-
txn.execute(
- sql_template % (comparison_operator, order, order, order),
- (timestamp, room_id),
+ sql_template,
+ (room_id, timestamp),
)
row = txn.fetchone()
if row:
@@ -2200,7 +2233,15 @@ class EventsWorkerStore(SQLBaseStore):
return result is not None
async def get_partial_state_events_batch(self, room_id: str) -> List[str]:
- """Get a list of events in the given room that have partial state"""
+ """
+ Get a list of events in the given room that:
+ - have partial state; and
+ - are ready to be resynced (because they have no prev_events that are
+ partial-stated)
+
+ See the docstring on `_get_partial_state_events_batch_txn` for more
+ information.
+ """
return await self.db_pool.runInteraction(
"get_partial_state_events_batch",
self._get_partial_state_events_batch_txn,
|