diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..4dc5da3fe8 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
},
)
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_caches", get_all_updated_caches_txn
- )
-
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
|