diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0442353287..56bfdc0b55 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -83,6 +83,14 @@ class StateStore(SQLBaseStore): for group, event_id_map in group_to_ids.items() }) + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups_state WHERE state_group = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] + def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: @@ -92,8 +100,10 @@ class StateStore(SQLBaseStore): if context.current_state_ids is None: continue - if context.state_group is not None: - state_groups[event.event_id] = context.state_group + state_groups[event.event_id] = context.state_group + + if self._have_persisted_state_group_txn(txn, context.state_group): + logger.info("Already persisted state_group: %r", context.state_group) continue state_event_ids = dict(context.current_state_ids) @@ -101,13 +111,11 @@ class StateStore(SQLBaseStore): if event.is_state(): state_event_ids[(event.type, event.state_key)] = event.event_id - state_group = context.new_state_group_id - self._simple_insert_txn( txn, table="state_groups", values={ - "id": state_group, + "id": context.state_group, "room_id": event.room_id, "event_id": event.event_id, }, @@ -118,7 +126,7 @@ class StateStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": state_group, + "state_group": context.state_group, "room_id": event.room_id, "type": key[0], "state_key": key[1], @@ -127,7 +135,6 @@ class StateStore(SQLBaseStore): for key, state_id in state_event_ids.items() ], ) - state_groups[event.event_id] = state_group self._simple_insert_many_txn( txn, @@ -526,3 +533,6 @@ class StateStore(SQLBaseStore): return self.runInteraction( "get_all_new_state_groups", get_all_new_state_groups_txn ) + + def get_next_state_group(self): + return self._state_groups_id_gen.get_next() |