diff --git a/changelog.d/17291.misc b/changelog.d/17291.misc
deleted file mode 100644
index b1f89a324d..0000000000
--- a/changelog.d/17291.misc
+++ /dev/null
@@ -1 +0,0 @@
-Do not block event sending/receiving while calulating large event auth chains.
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index d0e015bf19..84699a2ee1 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -617,17 +617,6 @@ class EventsPersistenceStorageController:
room_id, chunk
)
- with Measure(self._clock, "calculate_chain_cover_index_for_events"):
- # We now calculate chain ID/sequence numbers for any state events we're
- # persisting. We ignore out of band memberships as we're not in the room
- # and won't have their auth chain (we'll fix it up later if we join the
- # room).
- #
- # See: docs/auth_chain_difference_algorithm.md
- new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
- room_id, [e for e, _ in chunk]
- )
-
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
@@ -635,7 +624,6 @@ class EventsPersistenceStorageController:
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
- new_event_links=new_event_links,
)
return replaced_events
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index c6df13c064..66428e6c8e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,6 +34,7 @@ from typing import (
Optional,
Set,
Tuple,
+ Union,
cast,
)
@@ -99,23 +100,6 @@ class DeltaState:
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
-@attr.s(slots=True, auto_attribs=True)
-class NewEventChainLinks:
- """Information about new auth chain links that need to be added to the DB.
-
- Attributes:
- chain_id, sequence_number: the IDs corresponding to the event being
- inserted, and the starting point of the links
- links: Lists the links that need to be added, 2-tuple of the chain
- ID/sequence number of the end point of the link.
- """
-
- chain_id: int
- sequence_number: int
-
- links: List[Tuple[int, int]] = attr.Factory(list)
-
-
class PersistEventsStore:
"""Contains all the functions for writing events to the database.
@@ -164,7 +148,6 @@ class PersistEventsStore:
*,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
- new_event_links: Dict[str, NewEventChainLinks],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
@@ -234,7 +217,6 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
- new_event_links=new_event_links,
)
persist_event_counter.inc(len(events_and_contexts))
@@ -261,87 +243,6 @@ class PersistEventsStore:
(room_id,), frozenset(new_forward_extremities)
)
- async def calculate_chain_cover_index_for_events(
- self, room_id: str, events: Collection[EventBase]
- ) -> Dict[str, NewEventChainLinks]:
- # Filter to state events, and ensure there are no duplicates.
- state_events = []
- seen_events = set()
- for event in events:
- if not event.is_state() or event.event_id in seen_events:
- continue
-
- state_events.append(event)
- seen_events.add(event.event_id)
-
- if not state_events:
- return {}
-
- return await self.db_pool.runInteraction(
- "_calculate_chain_cover_index_for_events",
- self.calculate_chain_cover_index_for_events_txn,
- room_id,
- state_events,
- )
-
- def calculate_chain_cover_index_for_events_txn(
- self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
- ) -> Dict[str, NewEventChainLinks]:
- # We now calculate chain ID/sequence numbers for any state events we're
- # persisting. We ignore out of band memberships as we're not in the room
- # and won't have their auth chain (we'll fix it up later if we join the
- # room).
- #
- # See: docs/auth_chain_difference_algorithm.md
-
- # We ignore legacy rooms that we aren't filling the chain cover index
- # for.
- row = self.db_pool.simple_select_one_txn(
- txn,
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "has_auth_chain_index"),
- allow_none=True,
- )
- if row is None:
- return {}
-
- # Filter out already persisted events.
- rows = self.db_pool.simple_select_many_txn(
- txn,
- table="events",
- column="event_id",
- iterable=[e.event_id for e in state_events],
- keyvalues={},
- retcols=("event_id",),
- )
- already_persisted_events = {event_id for event_id, in rows}
- state_events = [
- event
- for event in state_events
- if event.event_id in already_persisted_events
- ]
-
- if not state_events:
- return {}
-
- # We need to know the type/state_key and auth events of the events we're
- # calculating chain IDs for. We don't rely on having the full Event
- # instances as we'll potentially be pulling more events from the DB and
- # we don't need the overhead of fetching/parsing the full event JSON.
- event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
- event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
- event_to_room_id = {e.event_id: e.room_id for e in state_events}
-
- return self._calculate_chain_cover_index(
- txn,
- self.db_pool,
- self.store.event_chain_id_gen,
- event_to_room_id,
- event_to_types,
- event_to_auth_chain,
- )
-
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
@@ -457,7 +358,6 @@ class PersistEventsStore:
inhibit_local_membership_updates: bool,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
- new_event_links: Dict[str, NewEventChainLinks],
) -> None:
"""Insert some number of room events into the necessary database tables.
@@ -566,9 +466,7 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
- self._persist_event_auth_chain_txn(
- txn, [e for e, _ in events_and_contexts], new_event_links
- )
+ self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@@ -598,7 +496,6 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events: List[EventBase],
- new_event_links: Dict[str, NewEventChainLinks],
) -> None:
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
@@ -622,40 +519,62 @@ class PersistEventsStore:
],
)
- if new_event_links:
- self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
- @classmethod
- def _add_chain_cover_index(
- cls,
- txn: LoggingTransaction,
- db_pool: DatabasePool,
- event_chain_id_gen: SequenceGenerator,
- event_to_room_id: Dict[str, str],
- event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, StrCollection],
- ) -> None:
- """Calculate and persist the chain cover index for the given events.
+ # We ignore legacy rooms that we aren't filling the chain cover index
+ # for.
+ rows = cast(
+ List[Tuple[str, Optional[Union[int, bool]]]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="rooms",
+ column="room_id",
+ iterable={event.room_id for event in events if event.is_state()},
+ keyvalues={},
+ retcols=("room_id", "has_auth_chain_index"),
+ ),
+ )
+ rooms_using_chain_index = {
+ room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
+ }
- Args:
- event_to_room_id: Event ID to the room ID of the event
- event_to_types: Event ID to type and state_key of the event
- event_to_auth_chain: Event ID to list of auth event IDs of the
- event (events with no auth events can be excluded).
- """
+ state_events = {
+ event.event_id: event
+ for event in events
+ if event.is_state() and event.room_id in rooms_using_chain_index
+ }
+
+ if not state_events:
+ return
+
+ # We need to know the type/state_key and auth events of the events we're
+ # calculating chain IDs for. We don't rely on having the full Event
+ # instances as we'll potentially be pulling more events from the DB and
+ # we don't need the overhead of fetching/parsing the full event JSON.
+ event_to_types = {
+ e.event_id: (e.type, e.state_key) for e in state_events.values()
+ }
+ event_to_auth_chain = {
+ e.event_id: e.auth_event_ids() for e in state_events.values()
+ }
+ event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
- new_event_links = cls._calculate_chain_cover_index(
+ self._add_chain_cover_index(
txn,
- db_pool,
- event_chain_id_gen,
+ self.db_pool,
+ self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
- cls._persist_chain_cover_index(txn, db_pool, new_event_links)
@classmethod
- def _calculate_chain_cover_index(
+ def _add_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
@@ -663,7 +582,7 @@ class PersistEventsStore:
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
- ) -> Dict[str, NewEventChainLinks]:
+ ) -> None:
"""Calculate the chain cover index for the given events.
Args:
@@ -671,10 +590,6 @@ class PersistEventsStore:
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
-
- Returns:
- A mapping with any new auth chain links we need to add, keyed by
- event ID.
"""
# Map from event ID to chain ID/sequence number.
@@ -793,11 +708,11 @@ class PersistEventsStore:
room_id = event_to_room_id.get(event_id)
if room_id:
e_type, state_key = event_to_types[event_id]
- db_pool.simple_upsert_txn(
+ db_pool.simple_insert_txn(
txn,
table="event_auth_chain_to_calculate",
- keyvalues={"event_id": event_id},
values={
+ "event_id": event_id,
"room_id": room_id,
"type": e_type,
"state_key": state_key,
@@ -809,7 +724,7 @@ class PersistEventsStore:
break
if not events_to_calc_chain_id_for:
- return {}
+ return
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
@@ -824,10 +739,23 @@ class PersistEventsStore:
)
chain_map.update(new_chain_tuples)
- to_return = {
- event_id: NewEventChainLinks(chain_id, sequence_number)
- for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
- }
+ db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chains",
+ keys=("event_id", "chain_id", "sequence_number"),
+ values=[
+ (event_id, c_id, seq)
+ for event_id, (c_id, seq) in new_chain_tuples.items()
+ ],
+ )
+
+ db_pool.simple_delete_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="event_id",
+ values=new_chain_tuples,
+ )
# Now we need to calculate any new links between chains caused by
# the new events.
@@ -897,38 +825,10 @@ class PersistEventsStore:
auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event
- to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
- return to_return
-
- @classmethod
- def _persist_chain_cover_index(
- cls,
- txn: LoggingTransaction,
- db_pool: DatabasePool,
- new_event_links: Dict[str, NewEventChainLinks],
- ) -> None:
- db_pool.simple_insert_many_txn(
- txn,
- table="event_auth_chains",
- keys=("event_id", "chain_id", "sequence_number"),
- values=[
- (event_id, new_links.chain_id, new_links.sequence_number)
- for event_id, new_links in new_event_links.items()
- ],
- )
-
- db_pool.simple_delete_many_txn(
- txn,
- table="event_auth_chain_to_calculate",
- keyvalues={},
- column="event_id",
- values=new_event_links,
- )
-
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@@ -938,16 +838,7 @@ class PersistEventsStore:
"target_chain_id",
"target_sequence_number",
),
- values=[
- (
- new_links.chain_id,
- new_links.sequence_number,
- target_chain_id,
- target_sequence_number,
- )
- for new_links in new_event_links.values()
- for (target_chain_id, target_sequence_number) in new_links.links
- ],
+ values=list(chain_links.get_additions()),
)
@staticmethod
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c4e216c308..81feb3ec29 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -447,14 +447,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
# Actually call the function that calculates the auth chain stuff.
- new_event_links = (
- persist_events_store.calculate_chain_cover_index_for_events_txn(
- txn, events[0].room_id, [e for e in events if e.is_state()]
- )
- )
- persist_events_store._persist_event_auth_chain_txn(
- txn, events, new_event_links
- )
+ persist_events_store._persist_event_auth_chain_txn(txn, events)
self.get_success(
persist_events_store.db_pool.runInteraction(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 1832a23714..0a6253e22c 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -365,19 +365,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- events = [
- cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
- for event_id in AUTH_GRAPH
- ]
- new_event_links = (
- self.persist_events.calculate_chain_cover_index_for_events_txn(
- txn, room_id, [e for e in events if e.is_state()]
- )
- )
self.persist_events._persist_event_auth_chain_txn(
txn,
- events,
- new_event_links,
+ [
+ cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+ for event_id in AUTH_GRAPH
+ ],
)
self.get_success(
@@ -635,20 +628,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
- events = [
- cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
- for event_id in auth_graph
- if event_id != "b"
- ]
- new_event_links = (
- self.persist_events.calculate_chain_cover_index_for_events_txn(
- txn, room_id, [e for e in events if e.is_state()]
- )
- )
self.persist_events._persist_event_auth_chain_txn(
txn,
- events,
- new_event_links,
+ [
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
+ for event_id in auth_graph
+ if event_id != "b"
+ ],
)
# Now we insert the event 'B' without a chain cover, by temporarily
@@ -661,14 +647,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
- events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
- new_event_links = (
- self.persist_events.calculate_chain_cover_index_for_events_txn(
- txn, room_id, [e for e in events if e.is_state()]
- )
- )
self.persist_events._persist_event_auth_chain_txn(
- txn, events, new_event_links
+ txn,
+ [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
self.store.db_pool.simple_update_txn(
|