Diffstat (limited to 'synapse/storage/databases/main/events_worker.py')
-rw-r--r--  synapse/storage/databases/main/events_worker.py  203
1 file changed, 122 insertions(+), 81 deletions(-)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 7cdc9fe98f..318fd7dc71 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging
 import threading
 import weakref
 from enum import Enum, auto
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Any,
     Collection,
-    Container,
     Dict,
     Iterable,
     List,
@@ -59,7 +59,6 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -77,10 +76,12 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import AsyncLruCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
@@ -212,26 +213,35 @@ class EventsWorkerStore(SQLBaseStore):
-            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
-            # updated over replication. (Multiple writers are not supported for
-            # SQLite).
+            # `StreamIdGenerator`; otherwise we use a `StreamIdGenerator` with
+            # `is_writer=False`, which is only advanced over replication.
+            # (Multiple writers are not supported for SQLite).
-            if hs.get_instance_name() in hs.config.worker.writers.events:
-                self._stream_id_gen = StreamIdGenerator(
-                    db_conn,
-                    "events",
-                    "stream_ordering",
-                )
-                self._backfill_id_gen = StreamIdGenerator(
-                    db_conn,
-                    "events",
-                    "stream_ordering",
-                    step=-1,
-                    extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-                )
-            else:
-                self._stream_id_gen = SlavedIdTracker(
-                    db_conn, "events", "stream_ordering"
-                )
-                self._backfill_id_gen = SlavedIdTracker(
-                    db_conn, "events", "stream_ordering", step=-1
-                )
+            self._stream_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
+            )
+            self._backfill_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                step=-1,
+                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+                is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
+            )
+
+        events_max = self._stream_id_gen.get_current_token()
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "current_state_delta_stream",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=events_max,  # As we share the stream id with events token
+            limit=1000,
+        )
+        self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache(
+            "_curr_state_delta_stream_cache",
+            min_curr_state_delta_id,
+            prefilled_cache=curr_state_delta_prefill,
+        )
 
         if hs.config.worker.run_background_tasks:
             # We periodically clean out old transaction ID mappings
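
With SlavedIdTracker gone, a single class covers both roles: only the writer allocates new stream ids, while non-writers merely track the position as it advances over replication. A minimal toy sketch of that is_writer pattern (the real StreamIdGenerator also handles in-flight ids, extra tables, and negative steps):

    class ToyStreamIdGenerator:
        """Toy model: one writer allocates ids; readers track them via advance()."""

        def __init__(self, current: int, is_writer: bool) -> None:
            self._current = current
            self._is_writer = is_writer

        def get_next(self) -> int:
            # Only the writer may mint new stream ids.
            assert self._is_writer, "non-writers cannot allocate stream ids"
            self._current += 1
            return self._current

        def advance(self, token: int) -> None:
            # Non-writers learn about newly written ids over replication.
            self._current = max(self._current, token)

        def get_current_token(self) -> int:
            return self._current

    writer = ToyStreamIdGenerator(current=10, is_writer=True)
    reader = ToyStreamIdGenerator(current=10, is_writer=False)
    reader.advance(writer.get_next())  # replication catches the reader up
    assert reader.get_current_token() == 11
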
@@ -374,7 +384,7 @@ class EventsWorkerStore(SQLBaseStore):
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
-            The event, or None if the event was not found.
+            The event, or None if the event was not found and allow_none is `True`.
         """
         if not isinstance(event_id, str):
             raise TypeError("Invalid event event_id %r" % (event_id,))
@@ -474,7 +484,7 @@ class EventsWorkerStore(SQLBaseStore):
             return []
 
         # there may be duplicates so we cast the list to a set
-        event_entry_map = await self._get_events_from_cache_or_db(
+        event_entry_map = await self.get_unredacted_events_from_cache_or_db(
             set(event_ids), allow_rejected=allow_rejected
         )
 
@@ -509,7 +519,9 @@ class EventsWorkerStore(SQLBaseStore):
                     continue
 
                 redacted_event_id = entry.event.redacts
-                event_map = await self._get_events_from_cache_or_db([redacted_event_id])
+                event_map = await self.get_unredacted_events_from_cache_or_db(
+                    [redacted_event_id]
+                )
                 original_event_entry = event_map.get(redacted_event_id)
                 if not original_event_entry:
                     # we don't have the redacted event (or it was rejected).
@@ -588,11 +600,16 @@ class EventsWorkerStore(SQLBaseStore):
         return events
 
     @cancellable
-    async def _get_events_from_cache_or_db(
-        self, event_ids: Iterable[str], allow_rejected: bool = False
+    async def get_unredacted_events_from_cache_or_db(
+        self,
+        event_ids: Iterable[str],
+        allow_rejected: bool = False,
     ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
+        Note that the events pulled by this function will not have any redactions
+        applied, and no guarantee is made about the ordering of the events returned.
+
         If events are pulled from the database, they will be cached for future lookups.
 
         Unknown events are omitted from the response.
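
The renamed method follows the usual cache-or-db pattern: serve hits from the in-memory event cache, batch-fetch only the misses, then cache those for future lookups. A hedged, generic sketch of that pattern (the real method additionally handles rejections, redaction checks, and concurrent fetches):

    from typing import Callable, Dict, Iterable, List

    def fetch_with_cache(
        ids: Iterable[str],
        cache: Dict[str, dict],
        fetch_from_db: Callable[[List[str]], Dict[str, dict]],
    ) -> Dict[str, dict]:
        """Return id -> row, consulting the cache first; unknown ids are omitted."""
        results: Dict[str, dict] = {}
        misses: List[str] = []
        for event_id in set(ids):  # de-duplicate, as the callers above do
            if event_id in cache:
                results[event_id] = cache[event_id]
            else:
                misses.append(event_id)
        if misses:
            fetched = fetch_from_db(misses)  # one round-trip for all misses
            cache.update(fetched)            # cached for future lookups
            results.update(fetched)
        return results
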
@@ -863,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_stripped_room_state_from_event_context(
         self,
         context: EventContext,
-        state_types_to_include: Container[str],
+        state_keys_to_include: StateFilter,
         membership_user_id: Optional[str] = None,
     ) -> List[JsonDict]:
         """
@@ -876,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
             context: The event context to retrieve state of the room from.
-            state_types_to_include: The type of state events to include.
+            state_keys_to_include: The state events to include, for each event type.
             membership_user_id: An optional user ID to include the stripped membership state
                 events of. This is useful when generating the stripped state of a room for
                 invites. We want to send membership events of the inviter, so that the
@@ -885,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             A list of dictionaries, each representing a stripped state event from the room.
         """
-        current_state_ids = await context.get_current_state_ids()
+        if membership_user_id:
+            types = chain(
+                state_keys_to_include.to_types(),
+                [(EventTypes.Member, membership_user_id)],
+            )
+            filter = StateFilter.from_types(types)
+        else:
+            filter = state_keys_to_include
+        selected_state_ids = await context.get_current_state_ids(filter)
 
         # We know this event is not an outlier, so this must be
         # non-None.
-        assert current_state_ids is not None
-
-        # The state to include
-        state_to_include_ids = [
-            e_id
-            for k, e_id in current_state_ids.items()
-            if k[0] in state_types_to_include
-            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
-        ]
+        assert selected_state_ids is not None
+
+        # Confusingly, get_current_state_ids may return events that are discarded by
+        # the filter, if they're in context._state_delta_due_to_event. Strip these away.
+        selected_state_ids = filter.filter_state(selected_state_ids)
 
-        state_to_include = await self.get_events(state_to_include_ids)
+        state_to_include = await self.get_events(selected_state_ids.values())
 
         return [
             {
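
StateFilter describes which (event type, state key) pairs a caller wants; a None state key acts as a wildcard over all state keys of that type, which is how state_keys_to_include and the single extra membership event combine above. A simplified sketch of the filtering behaviour (the real StateFilter supports more wildcard forms):

    from typing import Dict, Iterable, Optional, Tuple

    StateKey = Tuple[str, str]  # (event type, state key)

    def filter_state(
        state_ids: Dict[StateKey, str],
        types: Iterable[Tuple[str, Optional[str]]],
    ) -> Dict[StateKey, str]:
        """Keep entries matching (type, state_key); state_key=None matches any key."""
        type_list = list(types)
        wildcard_types = {t for t, sk in type_list if sk is None}
        exact_keys = {(t, sk) for t, sk in type_list if sk is not None}
        return {
            k: v
            for k, v in state_ids.items()
            if k[0] in wildcard_types or k in exact_keys
        }

    state = {
        ("m.room.create", ""): "$create",
        ("m.room.member", "@alice:a"): "$alice",
        ("m.room.member", "@bob:b"): "$bob",
    }
    # All create events, plus only bob's membership:
    assert filter_state(
        state, [("m.room.create", None), ("m.room.member", "@bob:b")]
    ) == {("m.room.create", ""): "$create", ("m.room.member", "@bob:b"): "$bob"}
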
@@ -1495,21 +1516,15 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
              a dict {event_id -> bool}
         """
-        # if the event cache contains the event, obviously we've seen it.
-
-        cache_results = {
-            event_id
-            for event_id in event_ids
-            if await self._get_event_cache.contains((event_id,))
-        }
-        results = dict.fromkeys(cache_results, True)
-        remaining = [
-            event_id for event_id in event_ids if event_id not in cache_results
-        ]
-        if not remaining:
-            return results
+        # TODO: We used to query the _get_event_cache here as a fast-path before
+        #  hitting the database, since if an event is in the cache we have
+        #  presumably seen it before.
+        #
+        #  But this is currently an invalid assumption due to the _get_event_cache
+        #  not being invalidated when purging events from a room. The optimisation can
+        #  be re-added after https://github.com/matrix-org/synapse/issues/13476
 
-        def have_seen_events_txn(txn: LoggingTransaction) -> None:
+        def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #
@@ -1517,16 +1532,17 @@ class EventsWorkerStore(SQLBaseStore):
 
             sql = "SELECT event_id FROM events AS e WHERE "
             clause, args = make_in_list_sql_clause(
-                txn.database_engine, "e.event_id", remaining
+                txn.database_engine, "e.event_id", event_ids
             )
             txn.execute(sql + clause, args)
             found_events = {eid for eid, in txn}
 
             # ... and then we can update the results for each key
-            results.update({eid: (eid in found_events) for eid in remaining})
+            return {eid: (eid in found_events) for eid in event_ids}
 
-        await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
-        return results
+        return await self.db_pool.runInteraction(
+            "have_seen_events", have_seen_events_txn
+        )
 
     @cached(max_entries=100000, tree=True)
     async def have_seen_event(self, room_id: str, event_id: str) -> bool:
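
The rewritten transaction resolves every event ID in a single IN (...) query, so the lookup can be satisfied from the events_event_id_key index alone. A self-contained sqlite3 sketch of the same batched-existence technique:

    import sqlite3
    from typing import Dict, List

    def have_seen(conn: sqlite3.Connection, ids: List[str]) -> Dict[str, bool]:
        if not ids:
            return {}
        placeholders = ", ".join("?" for _ in ids)
        rows = conn.execute(
            f"SELECT event_id FROM events WHERE event_id IN ({placeholders})", ids
        )
        found = {eid for (eid,) in rows}
        return {eid: eid in found for eid in ids}

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (event_id TEXT PRIMARY KEY)")
    conn.execute("INSERT INTO events VALUES ('$a')")
    assert have_seen(conn, ["$a", "$b"]) == {"$a": True, "$b": False}
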
@@ -1571,7 +1587,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id: The room ID to query.
 
         Returns:
-            dict[str:float] of complexity version to complexity.
+            Map of complexity version to complexity.
         """
         state_events = await self.get_current_state_event_counts(room_id)
 
@@ -1969,12 +1985,17 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
             room_id: room where the event lives
-            event_id: event to check
+            event: event to check (can't be an `outlier`)
 
         Returns:
             Boolean indicating whether it's an extremity
         """
 
+        assert not event.internal_metadata.is_outlier(), (
+            "is_event_next_to_backward_gap(...) can't be used with `outlier` events. "
+            "This function relies on `event_backward_extremities` which won't be filled in for `outliers`."
+        )
+
         def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
             # If the event in question has any of its prev_events listed as a
             # backward extremity, it's next to a gap.
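
That prev_events-against-backward-extremities check can be expressed as a single join. A hypothetical sketch, assuming simplified event_edges and event_backward_extremities tables (the real schema carries more columns):

    import sqlite3

    def is_next_to_backward_gap(
        conn: sqlite3.Connection, room_id: str, event_id: str
    ) -> bool:
        # True if any prev_event of `event_id` is a known backward extremity,
        # i.e. we have not yet backfilled the history behind it.
        row = conn.execute(
            """
            SELECT 1
            FROM event_edges AS ee
            JOIN event_backward_extremities AS be
                ON be.event_id = ee.prev_event_id AND be.room_id = ?
            WHERE ee.event_id = ?
            LIMIT 1
            """,
            (room_id, event_id),
        ).fetchone()
        return row is not None

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_edges (event_id TEXT, prev_event_id TEXT)")
    conn.execute("CREATE TABLE event_backward_extremities (room_id TEXT, event_id TEXT)")
    conn.execute("INSERT INTO event_edges VALUES ('$b', '$a')")
    conn.execute("INSERT INTO event_backward_extremities VALUES ('!r:x', '$a')")
    assert is_next_to_backward_gap(conn, "!r:x", "$b")
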
@@ -2024,12 +2045,17 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
             room_id: room where the event lives
-            event_id: event to check
+            event: event to check (can't be an `outlier`)
 
         Returns:
             Boolean indicating whether it's an extremity
         """
 
+        assert not event.internal_metadata.is_outlier(), (
+            "is_event_next_to_forward_gap(...) can't be used with `outlier` events. "
+            "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`."
+        )
+
         def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
             # If the event in question is a forward extremity, we will just
             # consider any potential forward gap as not a gap since it's one of
@@ -2110,13 +2136,33 @@ class EventsWorkerStore(SQLBaseStore):
             The closest event_id, or None if we can't find any event in
             the given direction.
         """
+        if direction == "b":
+            # Find closest event *before* a given timestamp. We use descending
+            # (which gives values largest to smallest) because we want the
+            # largest possible timestamp *before* the given timestamp.
+            comparison_operator = "<="
+            order = "DESC"
+        else:
+            # Find closest event *after* a given timestamp. We use ascending
+            # (which gives values smallest to largest) because we want the
+            # closest possible timestamp *after* the given timestamp.
+            comparison_operator = ">="
+            order = "ASC"
 
-        sql_template = """
+        sql_template = f"""
             SELECT event_id FROM events
             LEFT JOIN rejections USING (event_id)
             WHERE
-                origin_server_ts %s ?
-                AND room_id = ?
+                room_id = ?
+                AND origin_server_ts {comparison_operator} ?
+                /**
+                 * Make sure the event isn't an `outlier` because we have no way
+                 * to later check whether it's next to a gap. `outliers` do not
+                 * have entries in the `event_edges`, `event_forward_extremities`,
+                 * and `event_backward_extremities` tables to check against
+                 * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`).
+                 */
+                AND NOT outlier
                 /* Make sure event is not rejected */
                 AND rejections.event_id IS NULL
             /**
@@ -2126,27 +2172,14 @@ class EventsWorkerStore(SQLBaseStore):
              * Finally, we can tie-break based on when it was received on the server
              * (`stream_ordering`).
              */
-            ORDER BY origin_server_ts %s, depth %s, stream_ordering %s
+            ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order}
             LIMIT 1;
         """
 
         def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
-            if direction == "b":
-                # Find closest event *before* a given timestamp. We use descending
-                # (which gives values largest to smallest) because we want the
-                # largest possible timestamp *before* the given timestamp.
-                comparison_operator = "<="
-                order = "DESC"
-            else:
-                # Find closest event *after* a given timestamp. We use ascending
-                # (which gives values smallest to largest) because we want the
-                # closest possible timestamp *after* the given timestamp.
-                comparison_operator = ">="
-                order = "ASC"
-
             txn.execute(
-                sql_template % (comparison_operator, order, order, order),
-                (timestamp, room_id),
+                sql_template,
+                (room_id, timestamp),
             )
             row = txn.fetchone()
             if row:
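
Hoisting the direction logic out of the transaction pairs each comparison operator with its ORDER BY up front: "before" wants the largest origin_server_ts at or below the target (DESC), "after" the smallest at or above it (ASC). A self-contained sqlite3 sketch of the technique on a toy table:

    import sqlite3
    from typing import Optional

    def closest_event(conn: sqlite3.Connection, ts: int, direction: str) -> Optional[str]:
        # Only fixed strings are interpolated; values stay as bind parameters.
        op, order = ("<=", "DESC") if direction == "b" else (">=", "ASC")
        row = conn.execute(
            f"SELECT event_id FROM events WHERE origin_server_ts {op} ? "
            f"ORDER BY origin_server_ts {order} LIMIT 1",
            (ts,),
        ).fetchone()
        return row[0] if row else None

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (event_id TEXT, origin_server_ts INTEGER)")
    conn.executemany(
        "INSERT INTO events VALUES (?, ?)", [("$early", 100), ("$late", 300)]
    )
    assert closest_event(conn, 200, "b") == "$early"  # closest at-or-before 200
    assert closest_event(conn, 200, "f") == "$late"   # closest at-or-after 200
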
@@ -2200,7 +2233,15 @@ class EventsWorkerStore(SQLBaseStore):
         return result is not None
 
     async def get_partial_state_events_batch(self, room_id: str) -> List[str]:
-        """Get a list of events in the given room that have partial state"""
+        """
+        Get a list of events in the given room that:
+        - have partial state; and
+        - are ready to be resynced (because they have no prev_events that are
+          partial-stated)
+
+        See the docstring on `_get_partial_state_events_batch_txn` for more
+        information.
+        """
         return await self.db_pool.runInteraction(
             "get_partial_state_events_batch",
             self._get_partial_state_events_batch_txn,
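
One plausible shape for the underlying query, shown purely as a hypothetical sketch (the actual _get_partial_state_events_batch_txn may differ): select partial-state events in the room that have no prev_event which is itself still partial-stated.

    # Hypothetical sketch only; the exact query and column names are assumptions.
    GET_READY_PARTIAL_STATE_EVENTS = """
        SELECT pse.event_id
        FROM partial_state_events AS pse
        WHERE pse.room_id = ?
          AND NOT EXISTS (
              SELECT 1
              FROM event_edges AS ee
              JOIN partial_state_events AS prev ON prev.event_id = ee.prev_event_id
              WHERE ee.event_id = pse.event_id
          )
        LIMIT 100
    """
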