diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/_base.py | 58 | ||||
-rw-r--r-- | synapse/storage/events.py | 25 |
2 files changed, 58 insertions, 25 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e124161845..f1a5366b95 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging import sys import threading @@ -28,6 +29,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id from synapse.util.caches.descriptors import Cache from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.stringutils import exception_to_unicode @@ -64,6 +66,10 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", } +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object @@ -1184,6 +1190,56 @@ class SQLBaseStore(object): be invalidated. """ txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + keys = itertools.chain([room_id], members_changed) + self._send_invalidation_to_replication( + txn, _CURRENT_STATE_CACHE_NAME, keys, + ) + + def _invalidate_state_caches(self, room_id, members_changed): + """Invalidates caches that are based on the current state, but does + not stream invalidations down replication. + + Args: + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have + changed + """ + for member in members_changed: + self.get_rooms_for_user_with_stream_ordering.invalidate((member,)) + + for host in set(get_domain_from_id(u) for u in members_changed): + self.is_host_joined.invalidate((room_id, host)) + self.was_host_joined.invalidate((room_id, host)) + + self.get_users_in_room.invalidate((room_id,)) + self.get_room_summary.invalidate((room_id,)) + self.get_current_state_ids.invalidate((room_id,)) + + def _send_invalidation_to_replication(self, txn, cache_name, keys): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name (str) + keys (iterable[str]) + """ if isinstance(self.database_engine, PostgresEngine): # get_next() returns a context manager which is designed to wrap @@ -1201,7 +1257,7 @@ class SQLBaseStore(object): table="cache_invalidation_stream", values={ "stream_id": stream_id, - "cache_func": cache_func.__name__, + "cache_func": cache_name, "keys": list(keys), "invalidation_ts": self.clock.time_msec(), } diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 81b250480d..06db9e56e6 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -979,30 +979,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if ev_type == EventTypes.Member ) - for member in members_changed: - self._invalidate_cache_and_stream( - txn, self.get_rooms_for_user_with_stream_ordering, (member,) - ) - - for host in set(get_domain_from_id(u) for u in members_changed): - self._invalidate_cache_and_stream( - txn, self.is_host_joined, (room_id, host) - ) - self._invalidate_cache_and_stream( - txn, self.was_host_joined, (room_id, host) - ) - - self._invalidate_cache_and_stream( - txn, self.get_users_in_room, (room_id,) - ) - - self._invalidate_cache_and_stream( - txn, self.get_room_summary, (room_id,) - ) - - self._invalidate_cache_and_stream( - txn, self.get_current_state_ids, (room_id,) - ) + self._invalidate_state_caches_and_stream(txn, room_id, members_changed) def _update_forward_extremities_txn(self, txn, new_forward_extremities, max_stream_order): |