diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 8c1bf9227a..fa132d10b4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
from typing import List, Optional, Tuple
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+ """Tracks the "current" stream ID of a stream with a single writer.
+
+ See `AbstractStreamIdTracker` for more details.
+
+ Note that this class does not work correctly when there are multiple
+ writers.
+ """
+
def __init__(
self,
db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ class SlavedIdTracker:
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self) -> int:
- """
-
- Returns:
- int
- """
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
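
For context: the tracker contract this class now implements is driven entirely by replication. A worker never allocates stream IDs itself; it only observes the writer's position. A minimal sketch of that flow, with a hypothetical store and stream name:

    class ExampleWorkerStore:
        """Hypothetical worker-side store following a single-writer stream."""

        def __init__(self, id_tracker: SlavedIdTracker) -> None:
            self._example_id_gen = id_tracker

        def process_replication_rows(self, stream_name, instance_name, token, rows):
            if stream_name == "example_stream":
                # The writer advertises a new position over replication;
                # the tracker only ever moves forward.
                self._example_id_gen.advance(instance_name, token)

        def get_example_stream_token(self) -> int:
            # Everything up to this token has been persisted by the writer.
            return self._example_id_gen.get_current_token()
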
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f862862..7541e21de9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- # We assert this for the benefit of mypy
- assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a030e9299e..a390cfcb74 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
import attr
@@ -157,7 +157,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
- event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+ event_rows = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
)
@@ -191,7 +191,7 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
- ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+ ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
)
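
These local annotations are dropped rather than updated because the two store methods now declare precise row types (see events_worker.py below), so re-annotating would only widen what mypy already infers. A toy illustration:

    from typing import List, Tuple

    def fetch_rows() -> List[Tuple[int, str]]:
        return [(1, "$event:example.org")]

    rows = fetch_rows()  # inferred as List[Tuple[int, str]]
    # Adding `rows: List[Tuple]` here would widen the type back to tuples
    # of unknown shape, losing the element types.
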
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1605411b00..446204dbe5 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -764,7 +764,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
- self, event_ids: Iterable[str], allow_rejected: bool = False
+ self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 6edadea550..499a328201 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,6 +17,7 @@ import logging
from typing import (
Awaitable,
Callable,
+ Collection,
Dict,
Iterable,
List,
@@ -44,7 +45,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+ state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
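
The Iterable-to-Collection tightening here (and on `get_events` and `_get_events_from_db` below) changes the declared contract, not behaviour. A minimal illustration of what `Collection` additionally guarantees:

    from typing import Collection

    def load_events(event_ids: Collection[str]) -> None:
        # Unlike a bare Iterable, a Collection supports len() and `in`,
        # and can safely be iterated more than once.
        print(len(event_ids))
        if "$event:example.org" in event_ids:
            print("already have it")
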
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0623da9aa1..3056e64ff5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self,
stream_name: str,
instance_name: str,
- token: StreamToken,
+ token: int,
rows: Iterable[Any],
) -> None:
pass
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 06832221ad..c3440de2cb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@@ -64,9 +65,6 @@ event_counter = Counter(
)
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -108,16 +106,21 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- # Ideally we'd move these ID gens here, unfortunately some other ID
- # generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
- self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
+ # Since we have been configured to write, we ought to have id generators,
+ # rather than id trackers.
+ assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+ assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1553,11 +1556,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+ to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.set(
+ (cache_entry.event.event_id,), cache_entry
+ )
txn.call_after(prefill)
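
The assert-then-annotate step above is the standard mypy narrowing pattern: the attribute is declared on the worker store with the broad tracker type, and writer-only subclasses assert the stronger generator type at construction. A self-contained sketch of the pattern (class names abbreviated):

    import abc

    class Tracker(abc.ABC):
        @abc.abstractmethod
        def get_current_token(self) -> int:
            ...

    class Generator(Tracker):
        @abc.abstractmethod
        def get_next(self) -> int:
            ...

    class WorkerStore:
        # Workers only need to track the stream, so the attribute is
        # declared with the broad base type.
        _stream_id_gen: Tracker

    class WriterStore(WorkerStore):
        def allocate(self) -> None:
            # Writers require a full generator. The assert enforces this at
            # runtime and narrows the attribute's type for mypy, so the
            # get_next() call type-checks without a cast.
            assert isinstance(self._stream_id_gen, Generator)
            self._stream_id_gen.get_next()
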
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c6bcfe1c32..4cefc0a07e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
+ TYPE_CHECKING,
+ Any,
Collection,
Container,
Dict,
Iterable,
List,
+ NoReturn,
Optional,
Set,
Tuple,
+ cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,6 +82,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[int]
+ room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
+ self._stream_id_gen: AbstractStreamIdTracker
+ self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache = LruCache(
+ self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
+ self._event_fetch_list: List[
+ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+ ] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
+ def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
+ return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[False] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[False] = ...,
+ check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = False,
- allow_rejected: bool = False,
- allow_none: Literal[True] = False,
- check_room_id: Optional[str] = None,
+ get_prev_content: bool = ...,
+ allow_rejected: bool = ...,
+ allow_none: Literal[True] = ...,
+ check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Iterable[str],
+ event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,7 +632,7 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, _EventCacheEntry]
+ Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id):
+ def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, _EventCacheEntry]:
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -820,7 +851,7 @@ class EventsWorkerStore(SQLBaseStore):
for _, deferred in event_fetches_to_fail:
deferred.errback(exc)
- def _fetch_loop(self, conn: Connection) -> None:
+ def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
@@ -850,7 +881,9 @@ class EventsWorkerStore(SQLBaseStore):
self._fetch_event_list(conn, event_list)
def _fetch_event_list(
- self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+ self,
+ conn: LoggingDatabaseConnection,
+ event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -877,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire():
+ def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -887,16 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
+ def fire_errback(exc: Exception) -> None:
+ for _, d in event_list:
d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
+ self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
- self, event_ids: Iterable[str]
- ) -> Dict[str, _EventCacheEntry]:
+ self, event_ids: Collection[str]
+ ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -912,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_events = {}
+ fetched_event_ids: Set[str] = set()
+ fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids = set()
+ redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_events[event_id] = row
+ fetched_event_ids.add(event_id)
if row:
+ fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map = {}
+ event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
- if not row:
- continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -962,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
+ room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -1032,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map = {}
+ result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = _EventCacheEntry(
+ cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -1048,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -1061,7 +1095,7 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d = defer.Deferred()
+ events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify()
@@ -1216,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids):
+ async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1245,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- set[str]: The events we have already seen.
+ The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1268,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+ def have_seen_events_txn(
+ txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+ ) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1294,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str):
+ async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(self, txn, room_id):
+ def _get_current_state_event_counts_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1324,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id):
+ async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1332,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id (str)
+ room_id: The room ID to query.
Returns:
- dict[str:int] of complexity version to complexity.
+ dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1345,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self):
+ def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1365,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(txn):
+ def get_all_new_forward_event_rows(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1381,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1389,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1402,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(txn):
+ def get_ex_outlier_stream_rows_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
@@ -1420,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return txn.fetchall()
+ return cast(
+ List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ )
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1428,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, list]], int, bool]:
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1456,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(txn):
+ def get_all_new_backfill_event_rows(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
@@ -1470,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates = [(row[0], row[1:]) for row in txn]
+ new_event_updates: List[
+ Tuple[int, Tuple[str, str, str, str, str, str]]
+ ] = []
+ row: Tuple[int, str, str, str, str, str, str]
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1493,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- new_event_updates.extend((row[0], row[1:]) for row in txn)
+ # Type safety: iterating over `txn` yields `Tuple`, i.e.
+ # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+ # variadic tuple to a fixed length tuple and flags it up as an error.
+ for row in txn: # type: ignore[assignment]
+ new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@@ -1507,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple], int, bool]:
+ ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1527,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(txn):
+ def get_all_updated_current_state_deltas_txn(
+ txn: LoggingTransaction,
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1536,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
- def get_deltas_for_stream_id_txn(txn, stream_id):
+ def get_deltas_for_stream_id_txn(
+ txn: LoggingTransaction, stream_id: int
+ ) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return txn.fetchall()
+ return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple] = await self.db_pool.runInteraction(
+ rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1579,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id):
+ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1609,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(txn):
+ def get_next_event_to_expire_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1617,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return txn.fetchone()
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1681,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self):
+ async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn):
+ def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
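
The casts added throughout this file all follow one pattern: DB-API cursors are typed as yielding rows of unknown shape, so the expected shape is asserted statically at each query site. Condensed from `get_next_event_to_expire` above (helper name hypothetical):

    from typing import Any, Optional, Tuple, cast

    def _next_expiry_txn(txn: Any) -> Optional[Tuple[str, int]]:
        txn.execute("SELECT event_id, expiry_ts FROM event_expiry LIMIT 1")
        # fetchone() is typed as Optional[Tuple[Any, ...]]; the cast pins
        # down the row shape the SELECT guarantees. Casts are unchecked at
        # runtime, so the tuple must match the column list exactly.
        return cast(Optional[Tuple[str, int]], txn.fetchone())
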
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fa782023d4..3b63267395 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: Union[
- StreamIdGenerator, SlavedIdTracker
- ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+ self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn, "push_rules_stream", "stream_id"
+ )
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def get_next(self) -> AsyncContextManager[int]:
- raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+ """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+ Stream IDs are monotonically increasing or decreasing integers representing write
+ transactions. The "current" stream ID is the stream ID such that all transactions
+ with equal or smaller stream IDs have completed. Since transactions may complete out
+ of order, this is not the same as the stream ID of the last completed transaction.
+
+ Completed transactions include both committed transactions and transactions that
+ have been rolled back.
+ """
@abc.abstractmethod
- def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ def advance(self, instance_name: str, new_id: int) -> None:
+ """Advance the position of the named writer to the given ID, if greater
+ than the existing entry.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+
+ Returns:
+ The maximum stream id.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to `get_current_token`.
+ """
+ raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+ """Generates stream IDs for a stream that may have multiple writers.
+
+ Each stream ID represents a write transaction, whose completion is tracked
+ so that the "current" stream ID of the stream can be determined.
+
+ See `AbstractStreamIdTracker` for more details.
+ """
+
+ @abc.abstractmethod
+ def get_next(self) -> AsyncContextManager[int]:
+ """
+ Usage:
+ async with stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ """
+ Usage:
+ async with stream_id_gen.get_next_mult(n) as stream_ids:
+ # ... persist events ...
+ """
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
- """Used to generate new stream ids when persisting events while keeping
- track of which transactions have been completed.
+ """Generates and tracks stream IDs for a stream with a single writer.
- This allows us to get the "current" stream id, i.e. the stream id such that
- all ids less than or equal to it have completed. This handles the fact that
- persistence of events can complete out of order.
+ This class must only be used when the current Synapse process is the sole
+ writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ def advance(self, instance_name: str, new_id: int) -> None:
+ # `StreamIdGenerator` should only be used when there is a single writer,
+ # so replication should never happen.
+ raise Exception("Replication is not supported by StreamIdGenerator")
+
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
with self._lock:
self._current += self._step
next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
- """
- Usage:
- async with stream_id_gen.get_next(n) as stream_ids:
- # ... persist events ...
- """
with self._lock:
next_ids = range(
self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
-
- Returns:
- The maximum stream id.
- """
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
- """An ID generator that tracks a stream that can have multiple writers.
+ """Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
- """
- Usage:
- async with stream_id_gen.get_next_mult(5) as stream_ids:
- # ... persist events ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
-
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer."""
-
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
- """Advance the position of the named writer to the given ID, if greater
- than existing entry.
- """
-
new_id *= self._return_factor
with self._lock:
|