diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a6fda3f43c..1e731d56bd 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import collections
import itertools
import logging
from collections import OrderedDict
@@ -53,6 +54,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.databases.main.event_federation import EventFederationStore
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines import PostgresEngine
@@ -768,40 +770,26 @@ class PersistEventsStore:
# that have the same chain ID as the event.
# 2. For each retained auth event we:
# a. Add a link from the event's to the auth event's chain
- # ID/sequence number; and
- # b. Add a link from the event to every chain reachable by the
- # auth event.
+ # ID/sequence number
# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
- auth_chain_rows = cast(
- List[Tuple[int, int, int, int]],
- db_pool.simple_select_many_txn(
- txn,
- table="event_auth_chain_links",
- column="origin_chain_id",
- iterable={chain_id for chain_id, _ in chain_map.values()},
- keyvalues={},
- retcols=(
- "origin_chain_id",
- "origin_sequence_number",
- "target_chain_id",
- "target_sequence_number",
- ),
- ),
- )
- for (
- origin_chain_id,
- origin_sequence_number,
- target_chain_id,
- target_sequence_number,
- ) in auth_chain_rows:
- chain_links.add_link(
- (origin_chain_id, origin_sequence_number),
- (target_chain_id, target_sequence_number),
- new=False,
- )
+
+ for links in EventFederationStore._get_chain_links(
+ txn, {chain_id for chain_id, _ in chain_map.values()}
+ ):
+ for origin_chain_id, inner_links in links.items():
+ for (
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in inner_links:
+ chain_links.add_link(
+ (origin_chain_id, origin_sequence_number),
+ (target_chain_id, target_sequence_number),
+ new=False,
+ )
# We do this in toplogical order to avoid adding redundant links.
for event_id in sorted_topologically(
@@ -836,18 +824,6 @@ class PersistEventsStore:
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
- # Step 2b, add a link to chains reachable from the auth
- # event.
- for target_id, target_seq in chain_links.get_links_from(
- (auth_chain_id, auth_sequence_number)
- ):
- if target_id == chain_id:
- continue
-
- chain_links.add_link(
- (chain_id, sequence_number), (target_id, target_seq)
- )
-
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@@ -2451,31 +2427,6 @@ class _LinkMap:
current_links[src_seq] = target_seq
return True
- def get_links_from(
- self, src_tuple: Tuple[int, int]
- ) -> Generator[Tuple[int, int], None, None]:
- """Gets the chains reachable from the given chain/sequence number.
-
- Yields:
- The chain ID and sequence number the link points to.
- """
- src_chain, src_seq = src_tuple
- for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
- for link_src_seq, target_seq in sequence_numbers.items():
- if link_src_seq <= src_seq:
- yield target_id, target_seq
-
- def get_links_between(
- self, source_chain: int, target_chain: int
- ) -> Generator[Tuple[int, int], None, None]:
- """Gets the links between two chains.
-
- Yields:
- The source and target sequence numbers.
- """
-
- yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
-
def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
"""Gets any newly added links.
@@ -2502,9 +2453,24 @@ class _LinkMap:
if src_chain == target_chain:
return target_seq <= src_seq
- links = self.get_links_between(src_chain, target_chain)
- for link_start_seq, link_end_seq in links:
- if link_start_seq <= src_seq and target_seq <= link_end_seq:
- return True
+ # We have to traverse the link graph to check for indirect paths.
+ visited_chains = collections.Counter()
+ search = [(src_chain, src_seq)]
+ while search:
+ chain, seq = search.pop()
+ visited_chains[chain] = max(seq, visited_chains[chain])
+ for tc, links in self.maps.get(chain, {}).items():
+ for ss, ts in links.items():
+ # Don't revisit chains we've already seen, unless the target
+ # sequence number is higher than last time.
+ if ts <= visited_chains.get(tc, 0):
+ continue
+
+ if ss <= seq:
+ if tc == target_chain:
+ if target_seq <= ts:
+ return True
+ else:
+ search.append((tc, ts))
return False
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index c0b925444f..039aa91b92 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -132,12 +132,16 @@ Changes in SCHEMA_VERSION = 82
Changes in SCHEMA_VERSION = 83
- The event_txn_id is no longer used.
+
+Changes in SCHEMA_VERSION = 84
+ - No longer assumes that `event_auth_chain_links` holds transitive links, and
+ so read operations must do graph traversal.
"""
SCHEMA_COMPAT_VERSION = (
- # The event_txn_id table and tables from MSC2716 no longer exist.
- 83
+ # Transitive links are no longer written to `event_auth_chain_links`
+ 84
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
|