diff --git a/changelog.d/6748.misc b/changelog.d/6748.misc
new file mode 100644
index 0000000000..de320d4cd9
--- /dev/null
+++ b/changelog.d/6748.misc
@@ -0,0 +1 @@
+Propagate cache invalidates from workers to other workers.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 131e5acb09..bc1482a9bb 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -459,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_INVALIDATE_CACHE(self, cmd):
- self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.streamer.on_remote_server_up(cmd.data)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 6ebf944f66..ce60ae2e07 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
import logging
import random
-from typing import List
+from typing import Any, List
from six import itervalues
@@ -271,11 +271,14 @@ class ReplicationStreamer(object):
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
- def on_invalidate_cache(self, cache_func, keys):
+ async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
- getattr(self.store, cache_func).invalidate(tuple(keys))
+
+ # We invalidate the cache locally, but then also stream that to other
+ # workers.
+ await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
async def on_user_ip(
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index afa2b41c98..d4c44dcc75 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,7 +16,7 @@
import itertools
import logging
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, Tuple
from twisted.internet import defer
@@ -33,6 +33,26 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationStore(SQLBaseStore):
+ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ cache_func = getattr(self, cache_name, None)
+ if not cache_func:
+ return
+
+ cache_func.invalidate(keys)
+ await self.runInteraction(
+ "invalidate_cache_and_stream",
+ self._send_invalidation_to_replication,
+ cache_func.__name__,
+ keys,
+ )
+
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
|