diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 46042b2bf7..8120c305df 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import attr
from frozendict import frozendict
+from typing_extensions import Literal
from twisted.internet.defer import Deferred
@@ -106,7 +107,7 @@ class EventContext:
incomplete state.
"""
- rejected: Union[bool, str] = False
+ rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
prev_group: Optional[int] = None
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ad611b2c0b..6c12653bb3 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -49,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -235,7 +235,9 @@ class PersistEventsStore:
"""
results: List[str] = []
- def _get_events_which_are_prevs_txn(txn, batch):
+ def _get_events_which_are_prevs_txn(
+ txn: LoggingTransaction, batch: Collection[str]
+ ) -> None:
sql = """
SELECT prev_event_id, internal_metadata
FROM event_edges
@@ -285,7 +287,9 @@ class PersistEventsStore:
# and their prev events.
existing_prevs = set()
- def _get_prevs_before_rejected_txn(txn, batch):
+ def _get_prevs_before_rejected_txn(
+ txn: LoggingTransaction, batch: Collection[str]
+ ) -> None:
to_recursively_check = batch
while to_recursively_check:
@@ -515,7 +519,7 @@ class PersistEventsStore:
@classmethod
def _add_chain_cover_index(
cls,
- txn,
+ txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -809,7 +813,7 @@ class PersistEventsStore:
@staticmethod
def _allocate_chain_ids(
- txn,
+ txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
@@ -943,7 +947,7 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- ):
+ ) -> None:
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
to_insert = []
@@ -997,7 +1001,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
- ):
+ ) -> None:
for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
@@ -1155,7 +1159,7 @@ class PersistEventsStore:
txn, room_id, members_changed
)
- def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
"""Update the room version in the database based off current state
events.
@@ -1189,7 +1193,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
new_forward_extremities: Dict[str, Set[str]],
max_stream_order: int,
- ):
+ ) -> None:
for room_id in new_forward_extremities.keys():
self.db_pool.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1254,9 +1258,9 @@ class PersistEventsStore:
def _update_room_depths_txn(
self,
- txn,
+ txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- ):
+ ) -> None:
"""Update min_depth for each room
Args:
@@ -1385,7 +1389,7 @@ class PersistEventsStore:
# nothing to do here
return
- def event_dict(event):
+ def event_dict(event: EventBase) -> JsonDict:
d = event.get_dict()
d.pop("redacted", None)
d.pop("redacted_because", None)
@@ -1476,18 +1480,20 @@ class PersistEventsStore:
),
)
- def _store_rejected_events_txn(self, txn, events_and_contexts):
+ def _store_rejected_events_txn(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Add rows to the 'rejections' table for received events which were
rejected
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
+ txn: db connection
+ events_and_contexts: events we are persisting
Returns:
- list[(EventBase, EventContext)] new list, without the rejected
- events.
+ new list, without the rejected events.
"""
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
@@ -1508,7 +1514,7 @@ class PersistEventsStore:
events_and_contexts: List[Tuple[EventBase, EventContext]],
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
inhibit_local_membership_updates: bool = False,
- ):
+ ) -> None:
"""Update all the miscellaneous tables for new events
Args:
@@ -1602,7 +1608,11 @@ class PersistEventsStore:
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
- def _add_to_cache(self, txn, events_and_contexts):
+ def _add_to_cache(
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> None:
to_prefill = []
rows = []
@@ -1633,7 +1643,7 @@ class PersistEventsStore:
if not row["rejects"] and not row["redacts"]:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
- def prefill():
+ def prefill() -> None:
for cache_entry in to_prefill:
self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
@@ -1663,19 +1673,24 @@ class PersistEventsStore:
)
def insert_labels_for_event_txn(
- self, txn, event_id, labels, room_id, topological_ordering
- ):
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ labels: List[str],
+ room_id: str,
+ topological_ordering: int,
+ ) -> None:
"""Store the mapping between an event's ID and its labels, with one row per
(event_id, label) tuple.
Args:
- txn (LoggingTransaction): The transaction to execute.
- event_id (str): The event's ID.
- labels (list[str]): A list of text labels.
- room_id (str): The ID of the room the event was sent to.
- topological_ordering (int): The position of the event in the room's topology.
+ txn: The transaction to execute.
+ event_id: The event's ID.
+ labels: A list of text labels.
+ room_id: The ID of the room the event was sent to.
+ topological_ordering: The position of the event in the room's topology.
"""
- return self.db_pool.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1684,25 +1699,32 @@ class PersistEventsStore:
],
)
- def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ def _insert_event_expiry_txn(
+ self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+ ) -> None:
"""Save the expiry timestamp associated with a given event ID.
Args:
- txn (LoggingTransaction): The database transaction to use.
- event_id (str): The event ID the expiry timestamp is associated with.
- expiry_ts (int): The timestamp at which to expire (delete) the event.
+ txn: The database transaction to use.
+ event_id: The event ID the expiry timestamp is associated with.
+ expiry_ts: The timestamp at which to expire (delete) the event.
"""
- return self.db_pool.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
)
def _store_room_members_txn(
- self, txn, events, *, inhibit_local_membership_updates: bool = False
- ):
+ self,
+ txn: LoggingTransaction,
+ events: List[EventBase],
+ *,
+ inhibit_local_membership_updates: bool = False,
+ ) -> None:
"""
Store a room member in the database.
+
Args:
txn: The transaction to use.
events: List of events to store.
@@ -1742,6 +1764,7 @@ class PersistEventsStore:
)
for event in events:
+ assert event.internal_metadata.stream_ordering is not None
txn.call_after(
self.store._membership_stream_cache.entity_has_changed,
event.state_key,
@@ -1838,7 +1861,9 @@ class PersistEventsStore:
(parent_id, event.sender),
)
- def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+ def _handle_insertion_event(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
"""Handles keeping track of insertion events and edges/connections.
Part of MSC2716.
@@ -1899,7 +1924,7 @@ class PersistEventsStore:
},
)
- def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Handles inserting the batch edges/connections between the batch event
and an insertion event. Part of MSC2716.
@@ -1999,25 +2024,29 @@ class PersistEventsStore:
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
- def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("topic"), str):
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
- def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
if isinstance(event.content.get("name"), str):
self.store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
- def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+ def _store_room_message_txn(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
if isinstance(event.content.get("body"), str):
self.store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
- def _store_retention_policy_for_room_txn(self, txn, event):
+ def _store_retention_policy_for_room_txn(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
if not event.is_state():
logger.debug("Ignoring non-state m.room.retention event")
return
@@ -2077,8 +2106,11 @@ class PersistEventsStore:
)
def _set_push_actions_for_event_and_users_txn(
- self, txn, events_and_contexts, all_events_and_contexts
- ):
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+ ) -> None:
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
@@ -2086,12 +2118,10 @@ class PersistEventsStore:
from the push action staging area.
Args:
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
- all_events_and_contexts (list[(EventBase, EventContext)]): all
- events that we were going to persist. This includes events
- we've already persisted, etc, that wouldn't appear in
- events_and_context.
+ events_and_contexts: events we are persisting
+ all_events_and_contexts: all events that we were going to persist.
+ This includes events we've already persisted, etc, that wouldn't
+ appear in events_and_context.
"""
# Only non outlier events will have push actions associated with them,
@@ -2160,7 +2190,9 @@ class PersistEventsStore:
),
)
- def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+ def _remove_push_actions_for_event_id_txn(
+ self, txn: LoggingTransaction, room_id: str, event_id: str
+ ) -> None:
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
@@ -2171,7 +2203,9 @@ class PersistEventsStore:
(room_id, event_id),
)
- def _store_rejections_txn(self, txn, event_id, reason):
+ def _store_rejections_txn(
+ self, txn: LoggingTransaction, event_id: str, reason: str
+ ) -> None:
self.db_pool.simple_insert_txn(
txn,
table="rejections",
@@ -2183,8 +2217,10 @@ class PersistEventsStore:
)
def _store_event_state_mappings_txn(
- self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
- ):
+ self,
+ txn: LoggingTransaction,
+ events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+ ) -> None:
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
@@ -2241,7 +2277,9 @@ class PersistEventsStore:
state_group_id,
)
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ def _update_min_depth_for_room_txn(
+ self, txn: LoggingTransaction, room_id: str, depth: int
+ ) -> None:
min_depth = self.store._get_min_depth_interaction(txn, room_id)
if min_depth is not None and depth >= min_depth:
@@ -2254,7 +2292,9 @@ class PersistEventsStore:
values={"min_depth": depth},
)
- def _handle_mult_prev_events(self, txn, events):
+ def _handle_mult_prev_events(
+ self, txn: LoggingTransaction, events: List[EventBase]
+ ) -> None:
"""
For the given event, update the event edges table and forward and
backward extremities tables.
@@ -2272,7 +2312,9 @@ class PersistEventsStore:
self._update_backward_extremeties(txn, events)
- def _update_backward_extremeties(self, txn, events):
+ def _update_backward_extremeties(
+ self, txn: LoggingTransaction, events: List[EventBase]
+ ) -> None:
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3c49e7ec98..78e0773b2a 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
import attr
@@ -27,7 +27,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
)
- async def _background_reindex_search(self, progress, batch_size):
+ async def _background_reindex_search(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return result
- async def _background_reindex_gin_search(self, progress, batch_size):
+ async def _background_reindex_gin_search(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
- def create_index(conn):
+ def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
# we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
return 1
- async def _background_reindex_search_order(self, progress, batch_size):
+ async def _background_reindex_search_order(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
if not have_added_index:
- def create_index(conn):
+ def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
conn.set_session(autocommit=True)
c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
pg,
)
- def reindex_search_txn(txn):
+ def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
sql = (
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
" origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
else:
raise Exception("Unrecognized database engine")
- args.append(limit)
+ # mypy expects to append only a `str`, not an `int`
+ args.append(limit) # type: ignore[arg-type]
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
A set of strings.
"""
- def f(txn):
+ def f(txn: LoggingTransaction) -> Set[str]:
highlight_words = set()
for event in events:
# As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
return await self.db_pool.runInteraction("_find_highlights", f)
-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
|