diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 29de7e5bed..fbfd748406 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -53,6 +53,7 @@ class RoomBatchHandler:
# We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event.
+ assert most_recent_prev_event_id is not None
successor_event_ids = await self.store.get_successor_events(
most_recent_prev_event_id
)
@@ -139,6 +140,7 @@ class RoomBatchHandler:
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
+ assert most_recent_event_id is not None
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_event_id
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4710224708..dcfe8caf47 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,17 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import attr
from prometheus_client import Counter, Gauge
@@ -33,7 +43,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -135,7 +145,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
- room = await self.get_room(room_id)
+ room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -158,7 +168,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_ids_using_cover_index_txn(
- self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ event_ids: Collection[str],
+ include_given: bool,
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""
@@ -215,9 +229,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
- for batch in batch_iter(event_chains, 1000):
+ for batch2 in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch
+ txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@@ -297,7 +311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = set(event_ids)
while front:
- new_front = set()
+ new_front: Set[str] = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
to_fetch: List[str] = [] # Event IDs to fetch from DB
@@ -316,7 +330,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Note we need to batch up the results by event ID before
# adding to the cache.
- to_cache = {}
+ to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -349,7 +363,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
- room = await self.get_room(room_id)
+ room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -370,7 +384,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_difference_using_cover_index_txn(
- self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+ self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
@@ -444,9 +458,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
- for batch in batch_iter(set(seen_chains), 1000):
+ for batch2 in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch
+ txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
@@ -529,7 +543,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
def _get_auth_chain_difference_txn(
- self, txn, state_sets: List[Set[str]]
+ self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
@@ -602,7 +616,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# I think building a temporary list with fetchall is more efficient than
# just `search.extend(txn)`, but this is unconfirmed
- search.extend(txn.fetchall())
+ search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
# sort by depth
search.sort()
@@ -645,7 +659,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We parse the results and add the to the `found` set and the
# cache (note we need to batch up the results by event ID before
# adding to the cache).
- to_cache = {}
+ to_cache: Dict[str, List[Tuple[str, int]]] = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -696,7 +710,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n}
async def get_oldest_event_ids_with_depth_in_room(
- self, room_id
+ self, room_id: str
) -> List[Tuple[str, int]]:
"""Gets the oldest events(backwards extremities) in the room along with the
aproximate depth.
@@ -713,7 +727,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
- def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+ def get_oldest_event_ids_with_depth_in_room_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> List[Tuple[str, int]]:
# Assemble a dictionary with event_id -> depth for the oldest events
# we know of in the room. Backwards extremeties are the oldest
# events we know of in the room but we only know of them because
@@ -743,7 +759,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, False))
- return txn.fetchall()
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_oldest_event_ids_with_depth_in_room",
@@ -752,7 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
async def get_insertion_event_backward_extremities_in_room(
- self, room_id
+ self, room_id: str
) -> List[Tuple[str, int]]:
"""Get the insertion events we know about that we haven't backfilled yet.
@@ -768,7 +784,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
- def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
+ def get_insertion_event_backward_extremities_in_room_txn(
+ txn: LoggingTransaction, room_id: str
+ ) -> List[Tuple[str, int]]:
sql = """
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
/* We only want insertion events that are also marked as backwards extremities */
@@ -780,7 +798,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
- return txn.fetchall()
+ return cast(List[Tuple[str, int]], txn.fetchall())
return await self.db_pool.runInteraction(
"get_insertion_event_backward_extremities_in_room",
@@ -788,7 +806,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
- async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@@ -817,7 +835,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return max_depth_event_id, current_max_depth
- async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
+ async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
Args:
@@ -865,7 +883,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
- def _get_prev_events_for_room_txn(self, txn, room_id: str):
+ def _get_prev_events_for_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> List[str]:
# we just use the 10 newest events. Older events will become
# prev_events of future events.
@@ -896,7 +916,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
sorted by extremity count.
"""
- def _get_rooms_with_many_extremities_txn(txn):
+ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
where_clause = "1=1"
if room_id_filter:
where_clause = "room_id NOT IN (%s)" % (
@@ -937,7 +957,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_min_depth", self._get_min_depth_interaction, room_id
)
- def _get_min_depth_interaction(self, txn, room_id):
+ def _get_min_depth_interaction(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> Optional[int]:
min_depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_depth",
@@ -966,22 +988,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
- last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+ last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
# stream_ordering from before a restart
- last_change = max(self._stream_order_on_start, last_change)
+ last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
# provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it.
- if last_change > self.stream_ordering_month_ago:
+ if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
stream_ordering = min(last_change, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
- async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+ async def _get_forward_extremeties_for_room(
+ self, room_id: str, stream_ordering: int
+ ) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -989,7 +1013,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
stream_orderings from that point.
"""
- if stream_ordering <= self.stream_ordering_month_ago:
+ if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
@@ -1002,7 +1026,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
WHERE room_id = ?
"""
- def get_forward_extremeties_for_room_txn(txn):
+ def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
@@ -1104,8 +1128,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
]
async def get_backfill_events(
- self, room_id: str, seed_event_id_list: list, limit: int
- ):
+ self, room_id: str, seed_event_id_list: List[str], limit: int
+ ) -> List[EventBase]:
"""Get a list of Events for a given topic that occurred before (and
including) the events in seed_event_id_list. Return a list of max size `limit`
@@ -1123,10 +1147,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
events = await self.get_events_as_list(event_ids)
return sorted(
- events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+ # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
+ # But it's never None, because these events were previously persisted to the DB.
+ events,
+ key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
)
- def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+ def _get_backfill_events(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ seed_event_id_list: List[str],
+ limit: int,
+ ) -> Set[str]:
"""
We want to make sure that we do a breadth-first, "depth" ordered search.
We also handle navigating historical branches of history connected by
@@ -1139,7 +1172,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
limit,
)
- event_id_results = set()
+ event_id_results: Set[str] = set()
# In a PriorityQueue, the lowest valued entries are retrieved first.
# We're using depth as the priority in the queue and tie-break based on
@@ -1147,7 +1180,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# highest and newest-in-time message. We add events to the queue with a
# negative depth so that we process the newest-in-time messages first
# going backwards in time. stream_ordering follows the same pattern.
- queue = PriorityQueue()
+ queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
for seed_event_id in seed_event_id_list:
event_lookup_result = self.db_pool.simple_select_one_txn(
@@ -1253,7 +1286,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return event_id_results
- async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+ async def get_missing_events(
+ self,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[EventBase]:
ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
@@ -1264,11 +1303,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(ids)
- def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
+ def _get_missing_events(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[str]:
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
- event_results = []
+ event_results: List[str] = []
query = (
"SELECT prev_event_id FROM event_edges "
@@ -1311,7 +1357,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@wrap_as_background_process("delete_old_forward_extrem_cache")
async def _delete_old_forward_extrem_cache(self) -> None:
- def _delete_old_forward_extrem_cache_txn(txn):
+ def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = """
@@ -1324,7 +1370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) AND stream_ordering < ?
"""
txn.execute(
- sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+ sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
)
await self.db_pool.runInteraction(
@@ -1382,7 +1428,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
if self.db_pool.engine.supports_returning:
- def _remove_received_event_from_staging_txn(txn):
+ def _remove_received_event_from_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
sql = """
DELETE FROM federation_inbound_events_staging
WHERE origin = ? AND event_id = ?
@@ -1390,21 +1438,24 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (origin, event_id))
- return txn.fetchone()
+ row = cast(Optional[Tuple[int]], txn.fetchone())
- row = await self.db_pool.runInteraction(
+ if row is None:
+ return None
+
+ return row[0]
+
+ return await self.db_pool.runInteraction(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
)
- if row is None:
- return None
-
- return row[0]
else:
- def _remove_received_event_from_staging_txn(txn):
+ def _remove_received_event_from_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[int]:
received_ts = self.db_pool.simple_select_one_onecol_txn(
txn,
table="federation_inbound_events_staging",
@@ -1437,7 +1488,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, str]]:
"""Get the next event ID in the staging area for the given room."""
- def _get_next_staged_event_id_for_room_txn(txn):
+ def _get_next_staged_event_id_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, str]]:
sql = """
SELECT origin, event_id
FROM federation_inbound_events_staging
@@ -1448,7 +1501,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id,))
- return txn.fetchone()
+ return cast(Optional[Tuple[str, str]], txn.fetchone())
return await self.db_pool.runInteraction(
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@@ -1461,7 +1514,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, EventBase]]:
"""Get the next event in the staging area for the given room."""
- def _get_next_staged_event_for_room_txn(txn):
+ def _get_next_staged_event_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, str, str]]:
sql = """
SELECT event_json, internal_metadata, origin
FROM federation_inbound_events_staging
@@ -1471,7 +1526,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
- return txn.fetchone()
+ return cast(Optional[Tuple[str, str, str]], txn.fetchone())
row = await self.db_pool.runInteraction(
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@@ -1599,18 +1654,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@wrap_as_background_process("_get_stats_for_federation_staging")
- async def _get_stats_for_federation_staging(self):
+ async def _get_stats_for_federation_staging(self) -> None:
"""Update the prometheus metrics for the inbound federation staging area."""
- def _get_stats_for_federation_staging_txn(txn):
+ def _get_stats_for_federation_staging_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[int, int]:
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
txn.execute(
"SELECT min(received_ts) FROM federation_inbound_events_staging"
)
- (received_ts,) = txn.fetchone()
+ (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
# If there is nothing in the staging area default it to 0.
age = 0
@@ -1651,19 +1708,21 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
- async def clean_room_for_join(self, room_id):
- return await self.db_pool.runInteraction(
+ async def clean_room_for_join(self, room_id: str) -> None:
+ await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
- def _clean_room_for_join_txn(self, txn, room_id):
+ def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- async def _background_delete_non_state_event_auth(self, progress, batch_size):
- def delete_event_auth(txn):
+ async def _background_delete_non_state_event_auth(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ def delete_event_auth(txn: LoggingTransaction) -> bool:
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
|