diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index bc1bc97e19..1a7d4c5199 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -276,13 +276,6 @@ class EventsStore(SQLBaseStore):
events_and_contexts, stream_orderings
):
event.internal_metadata.stream_ordering = stream
- # Assign a state group_id in case a new id is needed for
- # this context. In theory we only need to assign this
- # for contexts that have current_state and aren't outliers
- # but that make the code more complicated. Assigning an ID
- # per event only causes the state_group_ids to grow as fast
- # as the stream_ordering so in practise shouldn't be a problem.
- context.new_state_group_id = self._state_groups_id_gen.get_next()
chunks = [
events_and_contexts[x:x + 100]
@@ -309,7 +302,6 @@ class EventsStore(SQLBaseStore):
try:
with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
- context.new_state_group_id = self._state_groups_id_gen.get_next()
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
@@ -523,7 +515,7 @@ class EventsStore(SQLBaseStore):
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
- state_group_id = context.state_group or context.new_state_group_id
+ state_group_id = context.state_group
self._simple_insert_txn(
txn,
table="ex_outlier_stream",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0442353287..56bfdc0b55 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.items()
})
+ def _have_persisted_state_group_txn(self, txn, state_group):
+ txn.execute(
+ "SELECT count(*) FROM state_groups_state WHERE state_group = ?",
+ (state_group,)
+ )
+ row = txn.fetchone()
+ return row and row[0]
+
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
@@ -92,8 +100,10 @@ class StateStore(SQLBaseStore):
if context.current_state_ids is None:
continue
- if context.state_group is not None:
- state_groups[event.event_id] = context.state_group
+ state_groups[event.event_id] = context.state_group
+
+ if self._have_persisted_state_group_txn(txn, context.state_group):
+ logger.info("Already persisted state_group: %r", context.state_group)
continue
state_event_ids = dict(context.current_state_ids)
@@ -101,13 +111,11 @@ class StateStore(SQLBaseStore):
if event.is_state():
state_event_ids[(event.type, event.state_key)] = event.event_id
- state_group = context.new_state_group_id
-
self._simple_insert_txn(
txn,
table="state_groups",
values={
- "id": state_group,
+ "id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
@@ -118,7 +126,7 @@ class StateStore(SQLBaseStore):
table="state_groups_state",
values=[
{
- "state_group": state_group,
+ "state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
@@ -127,7 +135,6 @@ class StateStore(SQLBaseStore):
for key, state_id in state_event_ids.items()
],
)
- state_groups[event.event_id] = state_group
self._simple_insert_many_txn(
txn,
@@ -526,3 +533,6 @@ class StateStore(SQLBaseStore):
return self.runInteraction(
"get_all_new_state_groups", get_all_new_state_groups_txn
)
+
+ def get_next_state_group(self):
+ return self._state_groups_id_gen.get_next()
|