diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..270b30800b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -279,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             new_front = set()
             for chunk in batch_iter(front, 100):
                 # Pull the auth events either from the cache or DB.
-                to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+                to_fetch: List[str] = []  # Event IDs to fetch from DB
                 for event_id in chunk:
                     res = self._event_auth_cache.get(event_id)
                     if res is None:
@@ -606,8 +615,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             # currently walking, either from cache or DB.
             search, chunk = search[:-100], search[-100:]
 
-            found = []  # Results found  # type: List[Tuple[str, str, int]]
-            to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+            found: List[Tuple[str, str, int]] = []  # Results found
+            to_fetch: List[str] = []  # Event IDs to fetch from DB
             for _, event_id in chunk:
                 res = self._event_auth_cache.get(event_id)
                 if res is None:
@@ -1384,7 +1393,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         count = await self.db_pool.simple_select_one_onecol(
             table="federation_inbound_events_staging",
             keyvalues={"room_id": room_id},
-            retcol="COALESCE(COUNT(*), 0)",
+            retcol="COUNT(*)",
             desc="prune_staged_events_in_room_count",
         )
 
@@ -1476,9 +1485,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""Update the prometheus metrics for the inbound federation staging area."""
def _get_stats_for_federation_staging_txn(txn):
- txn.execute(
- "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
- )
+ txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
(count,) = txn.fetchone()
txn.execute(
@@ -1514,7 +1521,12 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(