diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7967011afd..8df80664a2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,7 +57,7 @@ class SQLBaseStore(metaclass=ABCMeta):
pass

def _invalidate_state_caches(
- self, room_id: str, members_changed: Iterable[str]
+ self, room_id: str, members_changed: Collection[str]
) -> None:
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -66,11 +66,16 @@ class SQLBaseStore(metaclass=ABCMeta):
room_id: Room where state changed
members_changed: The user_ids of members that have changed
"""
+ # If there were any membership changes, purge the appropriate caches.
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
+ if members_changed:
+ self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+ self._attempt_to_invalidate_cache(
+ "get_users_in_room_with_profiles", (room_id,)
+ )

- self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
- self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,))
+ # Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 0024348067..c428dd5596 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@

import itertools
import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -236,7 +240,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate_all)
self._send_invalidation_to_replication(txn, cache_func.__name__, None)

- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ def _invalidate_state_caches_and_stream(
+ self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
) -> None:
"""Special case invalidation of caches based on current state.

We special case this so that we can batch the cache invalidations into a
@@ -244,8 +250,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):

Args:
txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
+ room_id: Room where state changed
+ members_changed: The user_ids of members that have changed
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)
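`_invalidate_state_caches_and_stream` registers the local purge with `txn.call_after(...)`, so the caches are only invalidated once the transaction actually commits (and, per the docstring, the invalidations can be batched into a single replication poke). A rough sketch of that deferred-invalidation pattern, assuming a much-simplified transaction object rather than Synapse's real `LoggingTransaction`:

```python
from typing import Any, Callable, List


class Transaction:
    """Simplified stand-in for a transaction that queues callbacks until commit."""

    def __init__(self) -> None:
        self._after_callbacks: List[Callable[[], None]] = []

    def call_after(self, callback: Callable[..., None], *args: Any) -> None:
        # Don't run the callback yet: the surrounding database transaction
        # could still fail and roll back.
        self._after_callbacks.append(lambda: callback(*args))

    def commit(self) -> None:
        # Only once the write is durable do the queued cache invalidations run,
        # so readers never see caches dropped for a write that never landed.
        for callback in self._after_callbacks:
            callback()


def invalidate_state_caches(room_id: str, members_changed: List[str]) -> None:
    print(f"purging state caches for {room_id}, members changed: {members_changed}")


txn = Transaction()
txn.call_after(invalidate_state_caches, "!room:example.com", ["@alice:example.com"])
# Nothing has been purged yet; the invalidation only happens on commit.
txn.commit()
```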
|