diff options
-rw-r--r-- | changelog.d/15233.misc | 1 | ||||
-rw-r--r-- | changelog.d/15737.feature | 1 | ||||
-rw-r--r-- | changelog.d/15755.misc | 1 | ||||
-rw-r--r-- | changelog.d/15758.bugfix | 1 | ||||
-rw-r--r-- | changelog.d/15770.bugfix | 1 | ||||
-rw-r--r-- | changelog.d/15772.doc | 1 | ||||
-rw-r--r-- | synapse/events/snapshot.py | 159 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 5 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 4 | ||||
-rw-r--r-- | synapse/handlers/pagination.py | 137 | ||||
-rw-r--r-- | synapse/storage/controllers/persist_events.py | 5 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 15 | ||||
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 2 | ||||
-rw-r--r-- | synapse/util/__init__.py | 5 | ||||
-rw-r--r-- | synapse/util/caches/lrucache.py | 8 | ||||
-rw-r--r-- | tests/events/test_snapshot.py | 3 | ||||
-rw-r--r-- | tests/storage/databases/main/test_events_worker.py | 49 | ||||
-rw-r--r-- | tests/storage/test_event_chain.py | 5 | ||||
-rw-r--r-- | tests/test_state.py | 11 |
19 files changed, 327 insertions, 87 deletions
diff --git a/changelog.d/15233.misc b/changelog.d/15233.misc new file mode 100644 index 0000000000..1dff00bf3c --- /dev/null +++ b/changelog.d/15233.misc @@ -0,0 +1 @@ +Replace `EventContext` fields `prev_group` and `delta_ids` with field `state_group_deltas`. diff --git a/changelog.d/15737.feature b/changelog.d/15737.feature new file mode 100644 index 0000000000..9a547b5ebd --- /dev/null +++ b/changelog.d/15737.feature @@ -0,0 +1 @@ +Improve `/messages` response time by avoiding backfill when we already have messages to return. diff --git a/changelog.d/15755.misc b/changelog.d/15755.misc new file mode 100644 index 0000000000..a65340d380 --- /dev/null +++ b/changelog.d/15755.misc @@ -0,0 +1 @@ +Fix requesting multiple keys at once over federation, related to [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983). diff --git a/changelog.d/15758.bugfix b/changelog.d/15758.bugfix new file mode 100644 index 0000000000..cabe25ca24 --- /dev/null +++ b/changelog.d/15758.bugfix @@ -0,0 +1 @@ +Avoid invalidating a cache that was just prefilled. diff --git a/changelog.d/15770.bugfix b/changelog.d/15770.bugfix new file mode 100644 index 0000000000..a65340d380 --- /dev/null +++ b/changelog.d/15770.bugfix @@ -0,0 +1 @@ +Fix requesting multiple keys at once over federation, related to [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983). diff --git a/changelog.d/15772.doc b/changelog.d/15772.doc new file mode 100644 index 0000000000..4d6c933c71 --- /dev/null +++ b/changelog.d/15772.doc @@ -0,0 +1 @@ +Document `looping_call()` functionality that will wait for the given function to finish before scheduling another. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e7e8225b8e..a43498ed4d 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr from immutabledict import immutabledict @@ -107,33 +107,32 @@ class EventContext(UnpersistedEventContextBase): state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None then this is the delta of the state between the two groups. - prev_group: If it is known, ``state_group``'s prev_group. Note that this being - None does not necessarily mean that ``state_group`` does not have - a prev_group! + state_group_deltas: If not empty, this is a dict collecting a mapping of the state + difference between state groups. - If the event is a state event, this is normally the same as - ``state_group_before_event``. + The keys are a tuple of two integers: the initial group and final state group. + The corresponding value is a state map representing the state delta between + these state groups. - If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` - will always also be ``None``. + The dictionary is expected to have at most two entries with state groups of: - Note that this *not* (necessarily) the state group associated with - ``_prev_state_ids``. + 1. The state group before the event and after the event. + 2. The state group preceding the state group before the event and the + state group before the event. - delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` - and ``state_group``. + This information is collected and stored as part of an optimization for persisting + events. partial_state: if True, we may be storing this event with a temporary, incomplete state. """ _storage: "StorageControllers" + state_group_deltas: Dict[Tuple[int, int], StateMap[str]] rejected: Optional[str] = None _state_group: Optional[int] = None state_group_before_event: Optional[int] = None _state_delta_due_to_event: Optional[StateMap[str]] = None - prev_group: Optional[int] = None - delta_ids: Optional[StateMap[str]] = None app_service: Optional[ApplicationService] = None partial_state: bool = False @@ -145,16 +144,14 @@ class EventContext(UnpersistedEventContextBase): state_group_before_event: Optional[int], state_delta_due_to_event: Optional[StateMap[str]], partial_state: bool, - prev_group: Optional[int] = None, - delta_ids: Optional[StateMap[str]] = None, + state_group_deltas: Dict[Tuple[int, int], StateMap[str]], ) -> "EventContext": return EventContext( storage=storage, state_group=state_group, state_group_before_event=state_group_before_event, state_delta_due_to_event=state_delta_due_to_event, - prev_group=prev_group, - delta_ids=delta_ids, + state_group_deltas=state_group_deltas, partial_state=partial_state, ) @@ -163,7 +160,7 @@ class EventContext(UnpersistedEventContextBase): storage: "StorageControllers", ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" - return EventContext(storage=storage) + return EventContext(storage=storage, state_group_deltas={}) async def persist(self, event: EventBase) -> "EventContext": return self @@ -183,13 +180,15 @@ class EventContext(UnpersistedEventContextBase): "state_group": self._state_group, "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, - "prev_group": self.prev_group, + "state_group_deltas": _encode_state_group_delta(self.state_group_deltas), "state_delta_due_to_event": _encode_state_dict( self._state_delta_due_to_event ), - "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, "partial_state": self.partial_state, + # add dummy delta_ids and prev_group for backwards compatibility + "delta_ids": None, + "prev_group": None, } @staticmethod @@ -204,17 +203,24 @@ class EventContext(UnpersistedEventContextBase): Returns: The event context. """ + # workaround for backwards/forwards compatibility: if the input doesn't have a value + # for "state_group_deltas" just assign an empty dict + state_group_deltas = input.get("state_group_deltas", None) + if state_group_deltas: + state_group_deltas = _decode_state_group_delta(state_group_deltas) + else: + state_group_deltas = {} + context = EventContext( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. storage=storage, state_group=input["state_group"], state_group_before_event=input["state_group_before_event"], - prev_group=input["prev_group"], + state_group_deltas=state_group_deltas, state_delta_due_to_event=_decode_state_dict( input["state_delta_due_to_event"] ), - delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], partial_state=input.get("partial_state", False), ) @@ -349,7 +355,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase): _storage: "StorageControllers" state_group_before_event: Optional[int] state_group_after_event: Optional[int] - state_delta_due_to_event: Optional[dict] + state_delta_due_to_event: Optional[StateMap[str]] prev_group_for_state_group_before_event: Optional[int] delta_ids_to_state_group_before_event: Optional[StateMap[str]] partial_state: bool @@ -380,26 +386,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase): events_and_persisted_context = [] for event, unpersisted_context in amended_events_and_context: - if event.is_state(): - context = EventContext( - storage=unpersisted_context._storage, - state_group=unpersisted_context.state_group_after_event, - state_group_before_event=unpersisted_context.state_group_before_event, - state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, - partial_state=unpersisted_context.partial_state, - prev_group=unpersisted_context.state_group_before_event, - delta_ids=unpersisted_context.state_delta_due_to_event, - ) - else: - context = EventContext( - storage=unpersisted_context._storage, - state_group=unpersisted_context.state_group_after_event, - state_group_before_event=unpersisted_context.state_group_before_event, - state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, - partial_state=unpersisted_context.partial_state, - prev_group=unpersisted_context.prev_group_for_state_group_before_event, - delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, - ) + state_group_deltas = unpersisted_context._build_state_group_deltas() + + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + state_group_deltas=state_group_deltas, + ) events_and_persisted_context.append((event, context)) return events_and_persisted_context @@ -452,11 +448,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase): # if the event isn't a state event the state group doesn't change if not self.state_delta_due_to_event: - state_group_after_event = self.state_group_before_event + self.state_group_after_event = self.state_group_before_event # otherwise if it is a state event we need to get a state group for it else: - state_group_after_event = await self._storage.state.store_state_group( + self.state_group_after_event = await self._storage.state.store_state_group( event.event_id, event.room_id, prev_group=self.state_group_before_event, @@ -464,16 +460,81 @@ class UnpersistedEventContext(UnpersistedEventContextBase): current_state_ids=None, ) + state_group_deltas = self._build_state_group_deltas() + return EventContext.with_state( storage=self._storage, - state_group=state_group_after_event, + state_group=self.state_group_after_event, state_group_before_event=self.state_group_before_event, state_delta_due_to_event=self.state_delta_due_to_event, + state_group_deltas=state_group_deltas, partial_state=self.partial_state, - prev_group=self.state_group_before_event, - delta_ids=self.state_delta_due_to_event, ) + def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]: + """ + Collect deltas between the state groups associated with this context + """ + state_group_deltas = {} + + # if we know the state group before the event and after the event, add them and the + # state delta between them to state_group_deltas + if self.state_group_before_event and self.state_group_after_event: + # if we have the state groups we should have the delta + assert self.state_delta_due_to_event is not None + state_group_deltas[ + ( + self.state_group_before_event, + self.state_group_after_event, + ) + ] = self.state_delta_due_to_event + + # the state group before the event may also have a state group which precedes it, if + # we have that and the state group before the event, add them and the state + # delta between them to state_group_deltas + if ( + self.prev_group_for_state_group_before_event + and self.state_group_before_event + ): + # if we have both state groups we should have the delta between them + assert self.delta_ids_to_state_group_before_event is not None + state_group_deltas[ + ( + self.prev_group_for_state_group_before_event, + self.state_group_before_event, + ) + ] = self.delta_ids_to_state_group_before_event + + return state_group_deltas + + +def _encode_state_group_delta( + state_group_delta: Dict[Tuple[int, int], StateMap[str]] +) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]: + if not state_group_delta: + return [] + + state_group_delta_encoded = [] + for key, value in state_group_delta.items(): + state_group_delta_encoded.append((key[0], key[1], _encode_state_dict(value))) + + return state_group_delta_encoded + + +def _decode_state_group_delta( + input: List[Tuple[int, int, List[Tuple[str, str, str]]]] +) -> Dict[Tuple[int, int], StateMap[str]]: + if not input: + return {} + + state_group_deltas = {} + for state_group_1, state_group_2, state_dict in input: + state_map = _decode_state_dict(state_dict) + assert state_map is not None + state_group_deltas[(state_group_1, state_group_2)] = state_map + + return state_group_deltas + def _encode_state_dict( state_dict: Optional[StateMap[str]], diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a2cf3a96c6..e5359ca558 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -260,7 +260,9 @@ class FederationClient(FederationBase): use_unstable = False for user_id, one_time_keys in query.items(): for device_id, algorithms in one_time_keys.items(): - if any(count > 1 for count in algorithms.values()): + # If more than one algorithm is requested, attempt to use the unstable + # endpoint. + if sum(algorithms.values()) > 1: use_unstable = True if algorithms: # For the stable query, choose only the first algorithm. @@ -296,6 +298,7 @@ class FederationClient(FederationBase): else: logger.debug("Skipping unstable claim client keys API") + # TODO Potentially attempt multiple queries and combine the results? return await self.transport_layer.claim_client_keys( user, destination, content, timeout ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9425b32507..61fa3b30af 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1016,7 +1016,9 @@ class FederationServer(FederationBase): for user_id, device_keys in result.items(): for device_id, keys in device_keys.items(): for key_id, key in keys.items(): - json_result.setdefault(user_id, {})[device_id] = {key_id: key} + json_result.setdefault(user_id, {}).setdefault(device_id, {})[ + key_id + ] = key logger.info( "Claimed one-time-keys: %s", diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index d5257acb7d..19b8728db9 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -40,6 +40,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How many single event gaps we tolerate returning in a `/messages` response before we +# backfill and try to fill in the history. This is an arbitrarily picked number so feel +# free to tune it in the future. +BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3 + @attr.s(slots=True, auto_attribs=True) class PurgeStatus: @@ -486,35 +491,35 @@ class PaginationHandler: room_id, room_token.stream ) - if not use_admin_priviledge and membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room, to save the effort of loading from the - # database. - - # This is only None if the room is world_readable, in which - # case "JOIN" would have been returned. - assert member_event_id + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + if ( + pagin_config.direction == Direction.BACKWARDS + and not use_admin_priviledge + and membership == Membership.LEAVE + ): + # This is only None if the room is world_readable, in which case + # "Membership.JOIN" would have been returned and we should never hit + # this branch. + assert member_event_id + + leave_token = await self.store.get_topological_token_for_event( + member_event_id + ) + assert leave_token.topological is not None - leave_token = await self.store.get_topological_token_for_event( - member_event_id + if leave_token.topological < curr_topo: + from_token = from_token.copy_and_replace( + StreamKeyType.ROOM, leave_token ) - assert leave_token.topological is not None - - if leave_token.topological < curr_topo: - from_token = from_token.copy_and_replace( - StreamKeyType.ROOM, leave_token - ) - - await self.hs.get_federation_handler().maybe_backfill( - room_id, - curr_topo, - limit=pagin_config.limit, - ) to_room_key = None if pagin_config.to_token: to_room_key = pagin_config.to_token.room_key + # Initially fetch the events from the database. With any luck, we can return + # these without blocking on backfill (handled below). events, next_key = await self.store.paginate_room_events( room_id=room_id, from_key=from_token.room_key, @@ -524,6 +529,94 @@ class PaginationHandler: event_filter=event_filter, ) + if pagin_config.direction == Direction.BACKWARDS: + # We use a `Set` because there can be multiple events at a given depth + # and we only care about looking at the unique continum of depths to + # find gaps. + event_depths: Set[int] = {event.depth for event in events} + sorted_event_depths = sorted(event_depths) + + # Inspect the depths of the returned events to see if there are any gaps + found_big_gap = False + number_of_gaps = 0 + previous_event_depth = ( + sorted_event_depths[0] if len(sorted_event_depths) > 0 else 0 + ) + for event_depth in sorted_event_depths: + # We don't expect a negative depth but we'll just deal with it in + # any case by taking the absolute value to get the true gap between + # any two integers. + depth_gap = abs(event_depth - previous_event_depth) + # A `depth_gap` of 1 is a normal continuous chain to the next event + # (1 <-- 2 <-- 3) so anything larger indicates a missing event (it's + # also possible there is no event at a given depth but we can't ever + # know that for sure) + if depth_gap > 1: + number_of_gaps += 1 + + # We only tolerate a small number single-event long gaps in the + # returned events because those are most likely just events we've + # failed to pull in the past. Anything longer than that is probably + # a sign that we're missing a decent chunk of history and we should + # try to backfill it. + # + # XXX: It's possible we could tolerate longer gaps if we checked + # that a given events `prev_events` is one that has failed pull + # attempts and we could just treat it like a dead branch of history + # for now or at least something that we don't need the block the + # client on to try pulling. + # + # XXX: If we had something like MSC3871 to indicate gaps in the + # timeline to the client, we could also get away with any sized gap + # and just have the client refetch the holes as they see fit. + if depth_gap > 2: + found_big_gap = True + break + previous_event_depth = event_depth + + # Backfill in the foreground if we found a big gap, have too many holes, + # or we don't have enough events to fill the limit that the client asked + # for. + missing_too_many_events = ( + number_of_gaps > BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD + ) + not_enough_events_to_fill_response = len(events) < pagin_config.limit + if ( + found_big_gap + or missing_too_many_events + or not_enough_events_to_fill_response + ): + did_backfill = ( + await self.hs.get_federation_handler().maybe_backfill( + room_id, + curr_topo, + limit=pagin_config.limit, + ) + ) + + # If we did backfill something, refetch the events from the database to + # catch anything new that might have been added since we last fetched. + if did_backfill: + events, next_key = await self.store.paginate_room_events( + room_id=room_id, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, + event_filter=event_filter, + ) + else: + # Otherwise, we can backfill in the background for eventual + # consistency's sake but we don't need to block the client waiting + # for a costly federation call and processing. + run_as_background_process( + "maybe_backfill_in_the_background", + self.hs.get_federation_handler().maybe_backfill, + room_id, + curr_topo, + limit=pagin_config.limit, + ) + next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key) # if no events are returned from pagination, that implies diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index f1d2c71c91..35c0680365 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -839,9 +839,8 @@ class EventsPersistenceStorageController: "group" % (ev.event_id,) ) continue - - if ctx.prev_group: - state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids + if ctx.state_group_deltas: + state_group_deltas.update(ctx.state_group_deltas) # We need to map the event_ids to their state groups. First, let's # check if the event is one we're persisting, in which case we can diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index e2e6eb479f..44af3357af 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1729,13 +1729,22 @@ class PersistEventsStore: if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - async def prefill() -> None: + async def external_prefill() -> None: for cache_entry in to_prefill: - await self.store._get_event_cache.set( + await self.store._get_event_cache.set_external( (cache_entry.event.event_id,), cache_entry ) - txn.async_call_after(prefill) + def local_prefill() -> None: + for cache_entry in to_prefill: + self.store._get_event_cache.set_local( + (cache_entry.event.event_id,), cache_entry + ) + + # The order these are called here is not as important as knowing that after the + # transaction is finished, the async_call_after will run before the call_after. + txn.async_call_after(external_prefill) + txn.call_after(local_prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: assert event.redacts is not None diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d93ffc4efa..7e7648c951 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -883,7 +883,7 @@ class EventsWorkerStore(SQLBaseStore): async def _invalidate_async_get_event_cache(self, event_id: str) -> None: """ - Invalidates an event in the asyncronous get event cache, which may be remote. + Invalidates an event in the asynchronous get event cache, which may be remote. Arguments: event_id: the event ID to invalidate diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 7ea0c4c36b..9f3b8741c1 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -116,6 +116,11 @@ class Clock: Waits `msec` initially before calling `f` for the first time. + If the function given to `looping_call` returns an awaitable/deferred, the next + call isn't scheduled until after the returned awaitable has finished. We get + this functionality thanks to this function being a thin wrapper around + `twisted.internet.task.LoopingCall`. + Note that the function will be called with no logcontext, so if it is anything other than trivial, you probably want to wrap it in run_as_background_process. diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 6137c85e10..be6554319a 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -842,7 +842,13 @@ class AsyncLruCache(Generic[KT, VT]): return self._lru_cache.get(key, update_metrics=update_metrics) async def set(self, key: KT, value: VT) -> None: - self._lru_cache.set(key, value) + # This will add the entries in the correct order, local first external second + self.set_local(key, value) + await self.set_external(key, value) + + async def set_external(self, key: KT, value: VT) -> None: + # This method should add an entry to any configured external cache, in this case noop. + pass def set_local(self, key: KT, value: VT) -> None: self._lru_cache.set(key, value) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 6687c28e8f..b5e42f9600 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase): self.assertEqual( context.state_group_before_event, d_context.state_group_before_event ) - self.assertEqual(context.prev_group, d_context.prev_group) - self.assertEqual(context.delta_ids, d_context.delta_ids) + self.assertEqual(context.state_group_deltas, d_context.state_group_deltas) self.assertEqual(context.app_service, d_context.app_service) self.assertEqual( diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 788500e38f..b223dc750b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -139,6 +139,55 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # That should result in a single db query to lookup self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + def test_persisting_event_prefills_get_event_cache(self) -> None: + """ + Test to make sure that the `_get_event_cache` is prefilled after we persist an + event and returns the updated value. + """ + event, event_context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": "conflabulation"}, + ) + ) + + # First, check `_get_event_cache` for the event we just made + # to verify it's not in the cache. + res = self.store._get_event_cache.get_local((event.event_id,)) + self.assertEqual(res, None, "Event was cached when it should not have been.") + + with LoggingContext(name="test") as ctx: + # Persist the event which should invalidate then prefill the + # `_get_event_cache` so we don't return stale values. + # Side Note: Apparently, persisting an event isn't a transaction in the + # sense that it is recorded in the LoggingContext + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self.get_success( + persistence.persist_event( + event, + event_context, + ) + ) + + # Check `_get_event_cache` again and we should see the updated fact + # that we now have the event cached after persisting it. + res = self.store._get_event_cache.get_local((event.event_id,)) + self.assertEqual(res.event, event, "Event not cached as expected.") # type: ignore + + # Try and fetch the event from the database. + self.get_success(self.store.get_event(event.event_id)) + + # Verify that the database hit was avoided. + self.assertEqual( + ctx.get_resource_usage().evt_db_fetch_count, + 0, + "Database was hit, which would not happen if event was cached.", + ) + def test_invalidate_cache_by_room_id(self) -> None: """ Test to make sure that all events associated with the given `(room_id,)` diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e39b63edac..48ebfadaab 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase): assert persist_events_store is not None persist_events_store._store_event_txn( txn, - [(e, EventContext(self.hs.get_storage_controllers())) for e in events], + [ + (e, EventContext(self.hs.get_storage_controllers(), {})) + for e in events + ], ) # Actually call the function that calculates the auth chain stuff. diff --git a/tests/test_state.py b/tests/test_state.py index 7a49b87953..eded38c766 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase): (e.event_id for e in old_state + [event]), current_state_ids.values() ) - self.assertIsNotNone(context.state_group_before_event) + assert context.state_group_before_event is not None + assert context.state_group is not None + self.assertEqual( + context.state_group_deltas.get( + (context.state_group_before_event, context.state_group) + ), + {(event.type, event.state_key): event.event_id}, + ) self.assertNotEqual(context.state_group_before_event, context.state_group) - self.assertEqual(context.state_group_before_event, context.prev_group) - self.assertEqual({("state", ""): event.event_id}, context.delta_ids) @defer.inlineCallbacks def test_trivial_annotate_message( |