diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bf316d5b..c7b660ac5a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
+ TYPE_CHECKING,
+ Any,
Collection,
Container,
Dict,
Iterable,
List,
+ NoReturn,
Optional,
Set,
Tuple,
+ cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[int]
+ room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
+ self._stream_id_gen: AbstractStreamIdTracker
+ self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache = LruCache(
+ self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
+ self._event_fetch_list: List[
+ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+ ] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
+ def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[False] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[False] = ...,
+ check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[True] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[True] = ...,
+ check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Iterable[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, _EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred())
+ Dict[str, EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id):
+ def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _do_fetch(self, conn: Connection) -> None:
+ def _maybe_start_fetch_thread(self) -> None:
+ """Starts an event fetch thread if we are not yet at the maximum number."""
+ with self._event_fetch_lock:
+ if (
+ self._event_fetch_list
+ and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+ ):
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process("fetch_events", self._fetch_thread)
+
+ async def _fetch_thread(self) -> None:
+ """Services requests for events from `_event_fetch_list`."""
+ exc = None
+ try:
+ await self.db_pool.runWithConnection(self._fetch_loop)
+ except BaseException as e:
+ exc = e
+ raise
+ finally:
+ should_restart = False
+ event_fetches_to_fail = []
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+ # There may still be work remaining in `_event_fetch_list` if we
+ # failed, or it was added in between us deciding to exit and
+ # decrementing `_event_fetch_ongoing`.
+ if self._event_fetch_list:
+ if exc is None:
+ # We decided to exit, but then some more work was added
+ # before `_event_fetch_ongoing` was decremented.
+ # If a new event fetch thread was not started, we should
+ # restart ourselves since the remaining event fetch threads
+ # may take a while to get around to the new work.
+ #
+ # Unfortunately it is not possible to tell whether a new
+ # event fetch thread was started, so we restart
+ # unconditionally. If we are unlucky, we will end up with
+ # an idle fetch thread, but it will time out after
+ # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+ # in any case.
+ #
+ # Note that multiple fetch threads may run down this path at
+ # the same time.
+ should_restart = True
+ elif isinstance(exc, Exception):
+ if self._event_fetch_ongoing == 0:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will
+ # handle them.
+ event_fetches_to_fail = self._event_fetch_list
+ self._event_fetch_list = []
+ else:
+ # We weren't the last remaining fetcher, so another
+ # fetcher will pick up the work. This will either happen
+ # after their existing work, however long that takes,
+ # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+ # they are idle.
+ pass
+ else:
+ # The exception is a `SystemExit`, `KeyboardInterrupt` or
+ # `GeneratorExit`. Don't try to do anything clever here.
+ pass
+
+ if should_restart:
+ # We exited cleanly but noticed more work.
+ self._maybe_start_fetch_thread()
+
+ if event_fetches_to_fail:
+ # We were the last remaining fetcher and failed.
+ # Fail any outstanding fetches since no one else will handle them.
+ assert exc is not None
+ with PreserveLoggingContext():
+ for _, deferred in event_fetches_to_fail:
+ deferred.errback(exc)
+
+ def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- try:
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- break
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ # There are no requests waiting. If we haven't yet reached the
+ # maximum iteration limit, wait for some more requests to turn up.
+ # Otherwise, bail out.
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ return
+
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
- finally:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ self._fetch_event_list(conn, event_list)
def _fetch_event_list(
- self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ self,
+ conn: LoggingDatabaseConnection,
+ event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire():
+ def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
+ def fire_errback(exc: Exception) -> None:
+ for _, d in event_list:
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
+ self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
- self, event_ids: Iterable[str]
- ) -> Dict[str, _EventCacheEntry]:
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_events = {}
+ fetched_event_ids: Set[str] = set()
+ fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids = set()
+ redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_events[event_id] = row
+ fetched_event_ids.add(event_id)
if row:
+ fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map = {}
+ event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
- if not row:
- continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
+ room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map = {}
+ result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = _EventCacheEntry(
+ cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d = defer.Deferred()
+ events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
-
self._event_fetch_lock.notify()
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.db_pool.runWithConnection, self._do_fetch
- )
+ self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- set[str]: The events we have already seen.
+ The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+ def have_seen_events_txn(
+ txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+ ) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str):
+ async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(self, txn, room_id):
+ def _get_current_state_event_counts_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- dict[str:int] of complexity version to complexity.
+ dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self):
+ def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(txn):
+ def get_all_new_forward_event_rows(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(txn):
+ def get_ex_outlier_stream_rows_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(txn):
+ def get_all_new_backfill_event_rows(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" AND instance_name = ?"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates = [(row[0], row[1:]) for row in txn]
+ new_event_updates: List[
+ Tuple[int, Tuple[str, str, str, str, str, str]]
+ ] = []
+ row: Tuple[int, str, str, str, str, str, str]
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " se.state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- new_event_updates.extend((row[0], row[1:]) for row in txn)
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(txn):
+ def get_all_updated_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
- def get_deltas_for_stream_id_txn(txn, stream_id):
+ def get_deltas_for_stream_id_txn(
+ txn: LoggingTransaction, stream_id: int
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple] = await self.db_pool.runInteraction(
+ rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id):
+ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(txn):
+ def get_next_event_to_expire_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return txn.fetchone()
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self):
+ async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn):
+ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
@@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)
+
+ async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a backward gap of missing events.
+ <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
+
+ Args:
+ room_id: room where the event lives
+ event_id: event to check
+
+ Returns:
+ Boolean indicating whether it's an extremity
+ """
+
+ def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question has any of its prev_events listed as a
+ # backward extremity, it's next to a gap.
+ #
+ # We can't just check the backward edges in `event_edges` because
+ # when we persist events, we will also record the prev_events as
+ # edges to the event in question regardless of whether we have those
+ # prev_events yet. We need to check whether those prev_events are
+ # backward extremities, also known as gaps, that need to be
+ # backfilled.
+ backward_extremity_query = """
+ SELECT 1 FROM event_backward_extremities
+ WHERE
+ room_id = ?
+ AND %s
+ LIMIT 1
+ """
+
+ # If the event in question is a backward extremity or has any of its
+ # prev_events listed as a backward extremity, it's next to a
+ # backward gap.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "event_id",
+ [event.event_id] + list(event.prev_event_ids()),
+ )
+
+ txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
+ backward_extremities = txn.fetchall()
+
+ # We consider any backward extremity as a backward gap
+ if len(backward_extremities):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_backward_gap_txn",
+ is_event_next_to_backward_gap_txn,
+ )
+
+ async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
+ """Check if the given event is next to a forward gap of missing events.
+ The gap in front of the latest events is not considered a gap.
+ <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
+ <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
+
+ Args:
+ room_id: room where the event lives
+ event_id: event to check
+
+ Returns:
+ Boolean indicating whether it's an extremity
+ """
+
+ def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
+ # If the event in question is a forward extremity, we will just
+ # consider any potential forward gap as not a gap since it's one of
+ # the latest events in the room.
+ #
+ # `event_forward_extremities` does not include backfilled or outlier
+ # events so we can't rely on it to find forward gaps. We can only
+ # use it to determine whether a message is the latest in the room.
+ #
+ # We can't combine this query with the `forward_edge_query` below
+ # because if the event in question has no forward edges (isn't
+ # referenced by any other event's prev_events) but is in
+ # `event_forward_extremities`, we don't want to return 0 rows and
+ # say it's next to a gap.
+ forward_extremity_query = """
+ SELECT 1 FROM event_forward_extremities
+ WHERE
+ room_id = ?
+ AND event_id = ?
+ LIMIT 1
+ """
+
+ # Check to see whether the event in question is already referenced
+ # by another event. If we don't see any edges, we're next to a
+ # forward gap.
+ forward_edge_query = """
+ SELECT 1 FROM event_edges
+ /* Check to make sure the event referencing our event in question is not rejected */
+ LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+ WHERE
+ event_edges.room_id = ?
+ AND event_edges.prev_event_id = ?
+ /* It's not a valid edge if the event referencing our event in
+ * question is rejected.
+ */
+ AND rejections.event_id IS NULL
+ LIMIT 1
+ """
+
+ # We consider any forward extremity as the latest in the room and
+ # not a forward gap.
+ #
+ # To expand, even though there is technically a gap at the front of
+ # the room where the forward extremities are, we consider those the
+ # latest messages in the room so asking other homeservers for more
+ # is useless. The new latest messages will just be federated as
+ # usual.
+ txn.execute(forward_extremity_query, (event.room_id, event.event_id))
+ forward_extremities = txn.fetchall()
+ if len(forward_extremities):
+ return False
+
+ # If there are no forward edges to the event in question (another
+ # event hasn't referenced this event in their prev_events), then we
+ # assume there is a forward gap in the history.
+ txn.execute(forward_edge_query, (event.room_id, event.event_id))
+ forward_edges = txn.fetchall()
+ if not len(forward_edges):
+ return True
+
+ return False
+
+ return await self.db_pool.runInteraction(
+ "is_event_next_to_gap_txn",
+ is_event_next_to_gap_txn,
+ )
+
+ async def get_event_id_for_timestamp(
+ self, room_id: str, timestamp: int, direction: str
+ ) -> Optional[str]:
+ """Find the closest event to the given timestamp in the given direction.
+
+ Args:
+ room_id: Room to fetch the event from
+ timestamp: The point in time (inclusive) we should navigate from in
+ the given direction to find the closest event.
+ direction: ["f"|"b"] to indicate whether we should navigate forward
+ or backward from the given timestamp to find the closest event.
+
+ Returns:
+ The closest event_id otherwise None if we can't find any event in
+ the given direction.
+ """
+
+ sql_template = """
+ SELECT event_id FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ origin_server_ts %s ?
+ AND room_id = ?
+ /* Make sure event is not rejected */
+ AND rejections.event_id IS NULL
+ ORDER BY origin_server_ts %s
+ LIMIT 1;
+ """
+
+ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
+ if direction == "b":
+ # Find closest event *before* a given timestamp. We use descending
+ # (which gives values largest to smallest) because we want the
+ # largest possible timestamp *before* the given timestamp.
+ comparison_operator = "<="
+ order = "DESC"
+ else:
+ # Find closest event *after* a given timestamp. We use ascending
+ # (which gives values smallest to largest) because we want the
+ # closest possible timestamp *after* the given timestamp.
+ comparison_operator = ">="
+ order = "ASC"
+
+ txn.execute(
+ sql_template % (comparison_operator, order), (timestamp, room_id)
+ )
+ row = txn.fetchone()
+ if row:
+ (event_id,) = row
+ return event_id
+
+ return None
+
+ if direction not in ("f", "b"):
+ raise ValueError("Unknown direction: %s" % (direction,))
+
+ return await self.db_pool.runInteraction(
+ "get_event_id_for_timestamp_txn",
+ get_event_id_for_timestamp_txn,
+ )
|