diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index c801a93b5b..f32cbb2dec 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
@@ -24,6 +24,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.opentracing import trace
+from synapse.replication.tcp.streams import UnPartialStatedEventStream
+from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -80,6 +82,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
+ self._instance_name: str = hs.get_instance_name()
+
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
+ if stream_name == UnPartialStatedEventStream.NAME:
+ for row in rows:
+ assert isinstance(row, UnPartialStatedEventStreamRow)
+ self._get_state_group_for_event.invalidate((row.event_id,))
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
@@ -404,18 +421,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
context: EventContext,
) -> None:
"""Update the state group for a partial state event"""
- await self.db_pool.runInteraction(
- "update_state_for_partial_state_event",
- self._update_state_for_partial_state_event_txn,
- event,
- context,
- )
+ async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id:
+ await self.db_pool.runInteraction(
+ "update_state_for_partial_state_event",
+ self._update_state_for_partial_state_event_txn,
+ event,
+ context,
+ un_partial_state_event_stream_id,
+ )
def _update_state_for_partial_state_event_txn(
self,
txn: LoggingTransaction,
event: EventBase,
context: EventContext,
+ un_partial_state_event_stream_id: int,
) -> None:
# we shouldn't have any outliers here
assert not event.internal_metadata.is_outlier()
@@ -436,7 +456,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# the event may now be rejected where it was not before, or vice versa,
# in which case we need to update the rejected flags.
- if bool(context.rejected) != (event.rejected_reason is not None):
+ rejection_status_changed = bool(context.rejected) != (
+ event.rejected_reason is not None
+ )
+ if rejection_status_changed:
self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
self.db_pool.simple_delete_one_txn(
@@ -445,8 +468,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
keyvalues={"event_id": event.event_id},
)
- # TODO(faster_joins): need to do something about workers here
- # https://github.com/matrix-org/synapse/issues/12994
txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
txn.call_after(
self._get_state_group_for_event.prefill,
@@ -454,6 +475,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_group,
)
+ self.db_pool.simple_insert_txn(
+ txn,
+ "un_partial_stated_event_stream",
+ {
+ "stream_id": un_partial_state_event_stream_id,
+ "instance_name": self._instance_name,
+ "event_id": event.event_id,
+ "rejection_status_changed": rejection_status_changed,
+ },
+ )
+
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|