diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 259cae5b37..e22aa0b9bc 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -123,9 +123,9 @@ class DataStore(
RelationsStore,
CensorEventsStore,
UIAuthStore,
+ EventForwardExtremitiesStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
- EventForwardExtremitiesStore,
LockStore,
SessionStore,
):
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index 6d2688d711..68901b4335 100644
--- a/synapse/storage/databases/main/events_forward_extremities.py
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -13,15 +13,20 @@
# limitations under the License.
import logging
-from typing import Dict, List
+from typing import Any, Dict, List
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
logger = logging.getLogger(__name__)
-class EventForwardExtremitiesStore(SQLBaseStore):
+class EventForwardExtremitiesStore(
+ EventFederationWorkerStore,
+ CacheInvalidationWorkerStore,
+):
async def delete_forward_extremities_for_room(self, room_id: str) -> int:
"""Delete any extra forward extremities for a room.
@@ -31,7 +36,7 @@ class EventForwardExtremitiesStore(SQLBaseStore):
Returns count deleted.
"""
- def delete_forward_extremities_for_room_txn(txn):
+ def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:
# First we need to get the event_id to not delete
sql = """
SELECT event_id FROM event_forward_extremities
@@ -82,10 +87,14 @@ class EventForwardExtremitiesStore(SQLBaseStore):
delete_forward_extremities_for_room_txn,
)
- async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+ async def get_forward_extremities_for_room(
+ self, room_id: str
+ ) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
- def get_forward_extremities_for_room_txn(txn):
+ def get_forward_extremities_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
|