diff options
Diffstat (limited to 'synapse/storage/databases/main/cache.py')
-rw-r--r-- | synapse/storage/databases/main/cache.py | 307 |
1 files changed, 307 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py new file mode 100644 index 0000000000..683afde52b --- /dev/null +++ b/synapse/storage/databases/main/cache.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging +from typing import Any, Iterable, List, Optional, Tuple + +from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams import BackfillStream, CachesStream +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + async def get_all_updated_caches( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for caches replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = """ + SELECT stream_id, cache_func, keys, invalidation_ts + FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, instance_name, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db_pool.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == EventsStream.NAME: + for row in rows: + self._process_event_stream_row(token, row) + elif stream_name == BackfillStream.NAME: + for row in rows: + self._invalidate_caches_for_event( + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, + backfilled=True, + ) + elif stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + + for row in rows: + if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) + + room_id = row.keys[0] + members_changed = set(row.keys[1:]) + self._invalidate_state_caches(room_id, members_changed) + else: + self._attempt_to_invalidate_cache(row.cache_func, row.keys) + + super().process_replication_rows(stream_name, instance_name, token, rows) + + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self._invalidate_caches_for_event( + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def _invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): + self._invalidate_get_event_cache(event_id) + + self.get_latest_event_ids_in_room.invalidate((room_id,)) + + self.get_unread_message_count_for_user.invalidate_many((room_id,)) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + + if not backfilled: + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + + if redacts: + self._invalidate_get_event_cache(redacts) + + if etype == EventTypes.Member: + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) + + async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.db_pool.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_all_cache_and_stream(self, txn, cache_func): + """Invalidates the entire cache and adds it to the cache stream so slaves + will know to invalidate their caches. + """ + + txn.call_after(cache_func.invalidate_all) + self._send_invalidation_to_replication(txn, cache_func.__name__, None) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication( + self, txn, cache_name: str, keys: Optional[Iterable[Any]] + ): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name + keys: Entry to invalidate. If None will invalidate all. + """ + + if cache_name == CURRENT_STATE_CACHE_NAME and keys is None: + raise Exception( + "Can't stream invalidate all with magic current state cache" + ) + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + stream_id = self._cache_id_gen.get_next_txn(txn) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + if keys is not None: + keys = list(keys) + + self.db_pool.simple_insert_txn( + txn, + table="cache_invalidation_stream_by_instance", + values={ + "stream_id": stream_id, + "instance_name": self._instance_name, + "cache_func": cache_name, + "keys": keys, + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_cache_stream_token(self, instance_name): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token(instance_name) + else: + return 0 |