diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index f855903c39..f32cbb2dec 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
@@ -24,6 +24,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.opentracing import trace
+from synapse.replication.tcp.streams import UnPartialStatedEventStream
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -82,6 +84,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
self._instance_name: str = hs.get_instance_name()
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
+ if stream_name == UnPartialStatedEventStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedEventStreamRow)
+ self._get_state_group_for_event.invalidate((row.event_id,))
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
Raises:
|