diff --git a/changelog.d/17338.misc b/changelog.d/17338.misc
new file mode 100644
index 0000000000..1a81bdef85
--- /dev/null
+++ b/changelog.d/17338.misc
@@ -0,0 +1 @@
+Do not block event sending/receiving while calculating large event auth chains.
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 84699a2ee1..d0e015bf19 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
room_id, chunk
)
+ with Measure(self._clock, "calculate_chain_cover_index_for_events"):
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
+ new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
+ room_id, [e for e, _ in chunk]
+ )
+
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
@@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
+ new_event_links=new_event_links,
)
return replaced_events
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fb132ef090..24abab4a23 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -148,6 +148,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
500000, "_event_auth_cache", size_callback=len
)
+ # Flag used by unit tests to disable fallback when there is no chain cover
+ # index.
+ self.tests_allow_no_chain_cover_index = True
+
self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
if isinstance(self.database_engine, PostgresEngine):
@@ -220,8 +224,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
- # for the events in question, so we fall back to the old method.
- pass
+ # for the events in question, so we fall back to the old method
+ # (except in tests)
+ if not self.tests_allow_no_chain_cover_index:
+ raise
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
@@ -271,7 +277,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
- logger.info(
+ logger.error(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
@@ -482,8 +488,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
- # for the events in question, so we fall back to the old method.
- pass
+ # for the events in question, so we fall back to the old method
+ # (except in tests)
+ if not self.tests_allow_no_chain_cover_index:
+ raise
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
@@ -710,7 +718,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
if events_missing_chain_info - event_to_auth_ids.keys():
# Uh oh, we somehow haven't correctly done the chain cover index,
# bail and fall back to the old method.
- logger.info(
+ logger.error(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info - event_to_auth_ids.keys(),
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 66428e6c8e..1f7acdb859 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,6 @@ from typing import (
Optional,
Set,
Tuple,
- Union,
cast,
)
@@ -100,6 +99,23 @@ class DeltaState:
return not self.to_delete and not self.to_insert and not self.no_longer_in_room
+@attr.s(slots=True, auto_attribs=True)
+class NewEventChainLinks:
+ """Information about new auth chain links that need to be added to the DB.
+
+ Attributes:
+ chain_id, sequence_number: the IDs corresponding to the event being
+ inserted, and the starting point of the links
+     links: The links that need to be added, each a 2-tuple of the chain
+         ID/sequence number of the end point of the link.
+ """
+
+ chain_id: int
+ sequence_number: int
+
+ links: List[Tuple[int, int]] = attr.Factory(list)
+
+
class PersistEventsStore:
"""Contains all the functions for writing events to the database.
@@ -148,6 +164,7 @@ class PersistEventsStore:
*,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
+ new_event_links: Dict[str, NewEventChainLinks],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
@@ -217,6 +234,7 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
+ new_event_links=new_event_links,
)
persist_event_counter.inc(len(events_and_contexts))
@@ -243,6 +261,87 @@ class PersistEventsStore:
(room_id,), frozenset(new_forward_extremities)
)
+ async def calculate_chain_cover_index_for_events(
+ self, room_id: str, events: Collection[EventBase]
+ ) -> Dict[str, NewEventChainLinks]:
+ # Filter to state events, and ensure there are no duplicates.
+ state_events = []
+ seen_events = set()
+ for event in events:
+ if not event.is_state() or event.event_id in seen_events:
+ continue
+
+ state_events.append(event)
+ seen_events.add(event.event_id)
+
+ if not state_events:
+ return {}
+
+ return await self.db_pool.runInteraction(
+ "_calculate_chain_cover_index_for_events",
+ self.calculate_chain_cover_index_for_events_txn,
+ room_id,
+ state_events,
+ )
+
+ def calculate_chain_cover_index_for_events_txn(
+ self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
+ ) -> Dict[str, NewEventChainLinks]:
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
+
+ # We ignore legacy rooms that we aren't filling the chain cover index
+ # for.
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("room_id", "has_auth_chain_index"),
+ allow_none=True,
+ )
+ if row is None or row[1] is False:
+ return {}
+
+ # Filter out events that we've already calculated.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in state_events],
+ keyvalues={},
+ retcols=("event_id",),
+ )
+ already_persisted_events = {event_id for event_id, in rows}
+ state_events = [
+ event
+ for event in state_events
+ if event.event_id not in already_persisted_events
+ ]
+
+ if not state_events:
+ return {}
+
+ # We need to know the type/state_key and auth events of the events we're
+ # calculating chain IDs for. We don't rely on having the full Event
+ # instances as we'll potentially be pulling more events from the DB and
+ # we don't need the overhead of fetching/parsing the full event JSON.
+ event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
+ event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
+ event_to_room_id = {e.event_id: e.room_id for e in state_events}
+
+ return self._calculate_chain_cover_index(
+ txn,
+ self.db_pool,
+ self.store.event_chain_id_gen,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ )
+
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
@@ -358,6 +457,7 @@ class PersistEventsStore:
inhibit_local_membership_updates: bool,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
+ new_event_links: Dict[str, NewEventChainLinks],
) -> None:
"""Insert some number of room events into the necessary database tables.
@@ -466,7 +566,9 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
- self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+ self._persist_event_auth_chain_txn(
+ txn, [e for e, _ in events_and_contexts], new_event_links
+ )
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@@ -496,7 +598,11 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events: List[EventBase],
+ new_event_links: Dict[str, NewEventChainLinks],
) -> None:
+ if new_event_links:
+ self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
+
# We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events):
return
@@ -519,62 +625,37 @@ class PersistEventsStore:
],
)
- # We now calculate chain ID/sequence numbers for any state events we're
- # persisting. We ignore out of band memberships as we're not in the room
- # and won't have their auth chain (we'll fix it up later if we join the
- # room).
- #
- # See: docs/auth_chain_difference_algorithm.md
-
- # We ignore legacy rooms that we aren't filling the chain cover index
- # for.
- rows = cast(
- List[Tuple[str, Optional[Union[int, bool]]]],
- self.db_pool.simple_select_many_txn(
- txn,
- table="rooms",
- column="room_id",
- iterable={event.room_id for event in events if event.is_state()},
- keyvalues={},
- retcols=("room_id", "has_auth_chain_index"),
- ),
- )
- rooms_using_chain_index = {
- room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
- }
-
- state_events = {
- event.event_id: event
- for event in events
- if event.is_state() and event.room_id in rooms_using_chain_index
- }
-
- if not state_events:
- return
+ @classmethod
+ def _add_chain_cover_index(
+ cls,
+ txn: LoggingTransaction,
+ db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, StrCollection],
+ ) -> None:
+ """Calculate and persist the chain cover index for the given events.
- # We need to know the type/state_key and auth events of the events we're
- # calculating chain IDs for. We don't rely on having the full Event
- # instances as we'll potentially be pulling more events from the DB and
- # we don't need the overhead of fetching/parsing the full event JSON.
- event_to_types = {
- e.event_id: (e.type, e.state_key) for e in state_events.values()
- }
- event_to_auth_chain = {
- e.event_id: e.auth_event_ids() for e in state_events.values()
- }
- event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+ Args:
+ event_to_room_id: Event ID to the room ID of the event
+ event_to_types: Event ID to type and state_key of the event
+ event_to_auth_chain: Event ID to list of auth event IDs of the
+ event (events with no auth events can be excluded).
+ """
- self._add_chain_cover_index(
+ new_event_links = cls._calculate_chain_cover_index(
txn,
- self.db_pool,
- self.store.event_chain_id_gen,
+ db_pool,
+ event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
+ cls._persist_chain_cover_index(txn, db_pool, new_event_links)
@classmethod
- def _add_chain_cover_index(
+ def _calculate_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
@@ -582,7 +663,7 @@ class PersistEventsStore:
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
- ) -> None:
+ ) -> Dict[str, NewEventChainLinks]:
"""Calculate the chain cover index for the given events.
Args:
@@ -590,6 +671,10 @@ class PersistEventsStore:
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
+
+ Returns:
+ A mapping with any new auth chain links we need to add, keyed by
+ event ID.
"""
# Map from event ID to chain ID/sequence number.
@@ -708,11 +793,11 @@ class PersistEventsStore:
room_id = event_to_room_id.get(event_id)
if room_id:
e_type, state_key = event_to_types[event_id]
- db_pool.simple_insert_txn(
+ db_pool.simple_upsert_txn(
txn,
table="event_auth_chain_to_calculate",
+ keyvalues={"event_id": event_id},
values={
- "event_id": event_id,
"room_id": room_id,
"type": e_type,
"state_key": state_key,
@@ -724,7 +809,7 @@ class PersistEventsStore:
break
if not events_to_calc_chain_id_for:
- return
+ return {}
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
@@ -739,23 +824,10 @@ class PersistEventsStore:
)
chain_map.update(new_chain_tuples)
- db_pool.simple_insert_many_txn(
- txn,
- table="event_auth_chains",
- keys=("event_id", "chain_id", "sequence_number"),
- values=[
- (event_id, c_id, seq)
- for event_id, (c_id, seq) in new_chain_tuples.items()
- ],
- )
-
- db_pool.simple_delete_many_txn(
- txn,
- table="event_auth_chain_to_calculate",
- keyvalues={},
- column="event_id",
- values=new_chain_tuples,
- )
+ to_return = {
+ event_id: NewEventChainLinks(chain_id, sequence_number)
+ for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
+ }
# Now we need to calculate any new links between chains caused by
# the new events.
@@ -825,10 +897,38 @@ class PersistEventsStore:
auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event
+ to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
+ return to_return
+
+ @classmethod
+ def _persist_chain_cover_index(
+ cls,
+ txn: LoggingTransaction,
+ db_pool: DatabasePool,
+ new_event_links: Dict[str, NewEventChainLinks],
+ ) -> None:
+ db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chains",
+ keys=("event_id", "chain_id", "sequence_number"),
+ values=[
+ (event_id, new_links.chain_id, new_links.sequence_number)
+ for event_id, new_links in new_event_links.items()
+ ],
+ )
+
+ db_pool.simple_delete_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="event_id",
+ values=new_event_links,
+ )
+
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@@ -838,7 +938,16 @@ class PersistEventsStore:
"target_chain_id",
"target_sequence_number",
),
- values=list(chain_links.get_additions()),
+ values=[
+ (
+ new_links.chain_id,
+ new_links.sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ )
+ for new_links in new_event_links.values()
+ for (target_chain_id, target_sequence_number) in new_links.links
+ ],
)
@staticmethod
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 81feb3ec29..c4e216c308 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
)
# Actually call the function that calculates the auth chain stuff.
- persist_events_store._persist_event_auth_chain_txn(txn, events)
+ new_event_links = (
+ persist_events_store.calculate_chain_cover_index_for_events_txn(
+ txn, events[0].room_id, [e for e in events if e.is_state()]
+ )
+ )
+ persist_events_store._persist_event_auth_chain_txn(
+ txn, events, new_event_links
+ )
self.get_success(
persist_events_store.db_pool.runInteraction(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0a6253e22c..088f0d24f9 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
+ events = [
+ cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+ for event_id in AUTH_GRAPH
+ ]
+ new_event_links = (
+ self.persist_events.calculate_chain_cover_index_for_events_txn(
+ txn, room_id, [e for e in events if e.is_state()]
+ )
+ )
self.persist_events._persist_event_auth_chain_txn(
txn,
- [
- cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
- for event_id in AUTH_GRAPH
- ],
+ events,
+ new_event_links,
)
self.get_success(
@@ -544,6 +551,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
rooms.
"""
+        # We allow the chain cover index to be missing for this test
+ self.hs.get_datastores().main.tests_allow_no_chain_cover_index = True
+
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -628,13 +638,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
+ events = [
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
+ for event_id in auth_graph
+ if event_id != "b"
+ ]
+ new_event_links = (
+ self.persist_events.calculate_chain_cover_index_for_events_txn(
+ txn, room_id, [e for e in events if e.is_state()]
+ )
+ )
self.persist_events._persist_event_auth_chain_txn(
txn,
- [
- cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
- for event_id in auth_graph
- if event_id != "b"
- ],
+ events,
+ new_event_links,
)
# Now we insert the event 'B' without a chain cover, by temporarily
@@ -647,9 +664,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
+ events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
+ new_event_links = (
+ self.persist_events.calculate_chain_cover_index_for_events_txn(
+ txn, room_id, [e for e in events if e.is_state()]
+ )
+ )
self.persist_events._persist_event_auth_chain_txn(
- txn,
- [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
+ txn, events, new_event_links
)
self.store.db_pool.simple_update_txn(
diff --git a/tests/unittest.py b/tests/unittest.py
index 18963b9e32..a7c20556a0 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -344,6 +344,8 @@ class HomeserverTestCase(TestCase):
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)
+ self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
+
# Honour the `use_frozen_dicts` config option. We have to do this
# manually because this is taken care of in the app `start` code, which
# we don't run. Plus we want to reset it on tearDown.
|