diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..baec35ee27 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
- "application_services_state", {"state": state.value}, ["as_id"]
+ "application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
desc="get_appservice_state",
)
if result:
- return ApplicationServiceState(result.get("state"))
+ return result.get("state")
return None
async def set_appservice_state(
@@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
- "application_services_state", {"as_id": service.id}, {"state": state.value}
+ "application_services_state", {"as_id": service.id}, {"state": state}
)
async def create_appservice_txn(
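For context on the appservice hunks above: they go back to binding the `ApplicationServiceState` member itself rather than its `.value`, and to returning the raw column value from `get_appservice_state`. A minimal sketch, assuming `ApplicationServiceState` is a plain string-valued `Enum` (member names and values here are illustrative), of what changes at the database boundary:

```python
from enum import Enum

class ApplicationServiceState(Enum):
    # Illustrative members; the real values mirror what the `state` column stores.
    UP = "up"
    DOWN = "down"

state = ApplicationServiceState.DOWN

# Binding `state.value` hands the driver the raw string stored in the column:
assert state.value == "down"

# Binding the member itself relies on the driver adapting an Enum, and a raw
# read of the column gives back a str unless it is re-wrapped explicitly:
assert ApplicationServiceState("down") is ApplicationServiceState.DOWN
```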
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -139,27 +139,6 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
- async def get_devices_by_auth_provider_session_id(
- self, auth_provider_id: str, auth_provider_session_id: str
- ) -> List[Dict[str, Any]]:
- """Retrieve the list of devices associated with a SSO IdP session ID.
-
- Args:
- auth_provider_id: The SSO IdP ID as defined in the server config
- auth_provider_session_id: The session ID within the IdP
- Returns:
- A list of dicts containing the device_id and the user_id of each device
- """
- return await self.db_pool.simple_select_list(
- table="device_auth_providers",
- keyvalues={
- "auth_provider_id": auth_provider_id,
- "auth_provider_session_id": auth_provider_session_id,
- },
- retcols=("user_id", "device_id"),
- desc="get_devices_by_auth_provider_session_id",
- )
-
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
@@ -1091,12 +1070,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def store_device(
- self,
- user_id: str,
- device_id: str,
- initial_device_display_name: Optional[str],
- auth_provider_id: Optional[str] = None,
- auth_provider_session_id: Optional[str] = None,
+ self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
) -> bool:
"""Ensure the given device is known; add it to the store if not
@@ -1105,8 +1079,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: id of device
initial_device_display_name: initial displayname of the device.
Ignored if device exists.
- auth_provider_id: The SSO IdP the user used, if any.
- auth_provider_session_id: The session ID (sid) got from a OIDC login.
Returns:
Whether the device was inserted or an existing device existed with that ID.
@@ -1143,18 +1115,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
- if auth_provider_id and auth_provider_session_id:
- await self.db_pool.simple_insert(
- "device_auth_providers",
- values={
- "user_id": user_id,
- "device_id": device_id,
- "auth_provider_id": auth_provider_id,
- "auth_provider_session_id": auth_provider_session_id,
- },
- desc="store_device_auth_provider",
- )
-
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
@@ -1208,14 +1168,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
- self.db_pool.simple_delete_many_txn(
- txn,
- table="device_auth_providers",
- column="device_id",
- values=device_ids,
- keyvalues={"user_id": user_id},
- )
-
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..ef5d1ef01e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
DELETE FROM event_auth
WHERE event_id IN (
SELECT event_id FROM events
- LEFT JOIN state_events AS se USING (room_id, event_id)
+ LEFT JOIN state_events USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
- AND se.state_key IS null
+ AND state_key IS null
)
"""
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..d957e770dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -16,7 +16,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
-from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -38,20 +37,6 @@ DEFAULT_HIGHLIGHT_ACTION = [
]
-class BasePushAction(TypedDict):
- event_id: str
- actions: List[Union[dict, str]]
-
-
-class HttpPushAction(BasePushAction):
- room_id: str
- stream_ordering: int
-
-
-class EmailPushAction(HttpPushAction):
- received_ts: Optional[int]
-
-
def _serialize_action(actions, is_highlight):
"""Custom serializer for actions. This allows us to "compress" common actions.
@@ -236,7 +221,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[HttpPushAction]:
+ ) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
@@ -341,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
- ) -> List[EmailPushAction]:
+ ) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
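With the `HttpPushAction`/`EmailPushAction` TypedDicts removed, both getters above are annotated simply as returning `List[dict]`. As a reference for the shape the pushers still receive, here is a small sketch reconstructed from the deleted TypedDict fields (the concrete values are made up):

```python
# Shape consumed by the HTTP pusher: one dict per unread push action.
http_push_action = {
    "event_id": "$abc123",           # illustrative event ID
    "room_id": "!room:example.org",  # illustrative room ID
    "stream_ordering": 1234,
    "actions": ["notify", {"set_tweak": "highlight", "value": False}],
}

# The email pusher additionally gets the time the event was received (may be None).
email_push_action = dict(http_push_action, received_ts=1638316800000)
```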
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..06832221ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
-from collections import OrderedDict
+from collections import OrderedDict, namedtuple
from typing import (
TYPE_CHECKING,
Any,
@@ -41,10 +41,9 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import AbstractStreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@@ -65,6 +64,9 @@ event_counter = Counter(
)
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -106,30 +108,23 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
+ # Ideally we'd move these ID gens here, unfortunately some other ID
+ # generators are chained off them so doing so is a bit of a PITA.
+ self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
+ self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
+
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
- # Since we have been configured to write, we ought to have id generators,
- # rather than id trackers.
- assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
- assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
-
- # Ideally we'd move these ID gens here, unfortunately some other ID
- # generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
- self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
-
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- *,
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
- use_negative_stream_ordering: bool = False,
- inhibit_local_membership_updates: bool = False,
+ backfilled: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -142,14 +137,7 @@ class PersistEventsStore:
room state
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
- use_negative_stream_ordering: Whether to start stream_ordering on
- the negative side and decrement. This should be set as True
- for backfilled events because backfilled events get a negative
- stream ordering so they don't come down incremental `/sync`.
- inhibit_local_membership_updates: Stop the local_current_membership
- from being updated by these events. This should be set to True
- for backfilled events because backfilled events in the past do
- not affect the current local state.
+ backfilled
Returns:
Resolves when the events have been persisted
@@ -171,7 +159,7 @@ class PersistEventsStore:
#
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
- if use_negative_stream_ordering:
+ if backfilled:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -188,13 +176,13 @@ class PersistEventsStore:
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
- inhibit_local_membership_updates=inhibit_local_membership_updates,
+ backfilled=backfilled,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(events_and_contexts))
- if stream < 0:
+ if not backfilled:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
@@ -328,9 +316,8 @@ class PersistEventsStore:
def _persist_events_txn(
self,
txn: LoggingTransaction,
- *,
events_and_contexts: List[Tuple[EventBase, EventContext]],
- inhibit_local_membership_updates: bool = False,
+ backfilled: bool,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
@@ -343,10 +330,7 @@ class PersistEventsStore:
Args:
txn
events_and_contexts: events to persist
- inhibit_local_membership_updates: Stop the local_current_membership
- from being updated by these events. This should be set to True
- for backfilled events because backfilled events in the past do
- not affect the current local state.
+ backfilled: True if the events were backfilled
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
@@ -379,7 +363,9 @@ class PersistEventsStore:
events_and_contexts
)
- self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+ self._update_room_depths_txn(
+ txn, events_and_contexts=events_and_contexts, backfilled=backfilled
+ )
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@@ -412,7 +398,7 @@ class PersistEventsStore:
txn,
events_and_contexts=events_and_contexts,
all_events_and_contexts=all_events_and_contexts,
- inhibit_local_membership_updates=inhibit_local_membership_updates,
+ backfilled=backfilled,
)
# We call this last as it assumes we've inserted the events into
@@ -575,9 +561,9 @@ class PersistEventsStore:
# fetch their auth event info.
while missing_auth_chains:
sql = """
- SELECT event_id, events.type, se.state_key, chain_id, sequence_number
+ SELECT event_id, events.type, state_key, chain_id, sequence_number
FROM events
- INNER JOIN state_events AS se USING (event_id)
+ INNER JOIN state_events USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
WHERE
"""
@@ -1214,6 +1200,7 @@ class PersistEventsStore:
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool,
):
"""Update min_depth for each room
@@ -1221,18 +1208,13 @@ class PersistEventsStore:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
+ backfilled (bool): True if the events were backfilled
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
- # Then update the `stream_ordering` position to mark the latest
- # event as the front of the room. This should not be done for
- # backfilled events because backfilled events have negative
- # stream_ordering and happened in the past so we know that we don't
- # need to update the stream_ordering tip/front for the room.
- assert event.internal_metadata.stream_ordering is not None
- if event.internal_metadata.stream_ordering >= 0:
+ if not backfilled:
txn.call_after(
self.store._events_stream_cache.entity_has_changed,
event.room_id,
@@ -1445,12 +1427,7 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _update_metadata_tables_txn(
- self,
- txn,
- *,
- events_and_contexts,
- all_events_and_contexts,
- inhibit_local_membership_updates: bool = False,
+ self, txn, events_and_contexts, all_events_and_contexts, backfilled
):
"""Update all the miscellaneous tables for new events
@@ -1462,10 +1439,7 @@ class PersistEventsStore:
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
- inhibit_local_membership_updates: Stop the local_current_membership
- from being updated by these events. This should be set to True
- for backfilled events because backfilled events in the past do
- not affect the current local state.
+ backfilled (bool): True if the events were backfilled
"""
# Insert all the push actions into the event_push_actions table.
@@ -1539,7 +1513,7 @@ class PersistEventsStore:
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
- inhibit_local_membership_updates=inhibit_local_membership_updates,
+ backfilled=backfilled,
)
# Insert event_reference_hashes table.
@@ -1579,13 +1553,11 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
- to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
+ to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.set(
- (cache_entry.event.event_id,), cache_entry
- )
+ self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
txn.call_after(prefill)
@@ -1666,19 +1638,8 @@ class PersistEventsStore:
txn, table="event_reference_hashes", values=vals
)
- def _store_room_members_txn(
- self, txn, events, *, inhibit_local_membership_updates: bool = False
- ):
- """
- Store a room member in the database.
- Args:
- txn: The transaction to use.
- events: List of events to store.
- inhibit_local_membership_updates: Stop the local_current_membership
- from being updated by these events. This should be set to True
- for backfilled events because backfilled events in the past do
- not affect the current local state.
- """
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database."""
def non_null_str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) and "\u0000" not in val else None
@@ -1721,7 +1682,7 @@ class PersistEventsStore:
# band membership", like a remote invite or a rejection of a remote invite.
if (
self.is_mine_id(event.state_key)
- and not inhibit_local_membership_updates
+ and not backfilled
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):
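The events.py hunks reintroduce `_EventCacheEntry` as a namedtuple and switch `prefill()` to positional access (`cache_entry[0].event_id`). A quick sketch of why both spellings work on a namedtuple; the event class here is a stand-in, not Synapse's `EventBase`:

```python
from collections import namedtuple

_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))

class FakeEvent:
    # Stand-in for EventBase, just enough to show the access patterns.
    event_id = "$abc123"

entry = _EventCacheEntry(event=FakeEvent(), redacted_event=None)

# A namedtuple supports both positional and attribute access, so
# entry[0].event_id and entry.event.event_id are interchangeable.
assert entry[0].event_id == entry.event.event_id == "$abc123"
assert entry.redacted_event is None
```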
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..c6bf316d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -15,18 +15,14 @@
import logging
import threading
from typing import (
- TYPE_CHECKING,
- Any,
Collection,
Container,
Dict,
Iterable,
List,
- NoReturn,
Optional,
Set,
Tuple,
- cast,
overload,
)
@@ -42,7 +38,6 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
- RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -61,18 +56,10 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
-)
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
- AbstractStreamIdTracker,
- MultiWriterIdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -82,13 +69,10 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
logger = logging.getLogger(__name__)
-# These values are used in the `enqueue_event` and `_fetch_loop` methods to
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
@@ -105,7 +89,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
-class EventCacheEntry:
+class _EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -145,7 +129,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
- room_version_id: Optional[str]
+ room_version_id: Optional[int]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -169,16 +153,9 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._stream_id_gen: AbstractStreamIdTracker
- self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -237,7 +214,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
- self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+ self._get_event_cache = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -246,21 +223,19 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
- str, ObservableDeferred[Dict[str, EventCacheEntry]]
+ str, ObservableDeferred[Dict[str, _EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
- self._event_fetch_list: List[
- Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
- ] = []
+ self._event_fetch_list = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
- def get_chain_id_txn(txn: Cursor) -> int:
+ def get_chain_id_txn(txn):
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return cast(Tuple[int], txn.fetchone())[0]
+ return txn.fetchone()[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -271,13 +246,7 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
- def process_replication_rows(
- self,
- stream_name: str,
- instance_name: str,
- token: int,
- rows: Iterable[Any],
- ) -> None:
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -311,10 +280,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = ...,
- allow_rejected: bool = ...,
- allow_none: Literal[False] = ...,
- check_room_id: Optional[str] = ...,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[False] = False,
+ check_room_id: Optional[str] = None,
) -> EventBase:
...
@@ -323,10 +292,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
- get_prev_content: bool = ...,
- allow_rejected: bool = ...,
- allow_none: Literal[True] = ...,
- check_room_id: Optional[str] = ...,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: Literal[True] = False,
+ check_room_id: Optional[str] = None,
) -> Optional[EventBase]:
...
@@ -388,7 +357,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
- event_ids: Collection[str],
+ event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -575,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
- ) -> Dict[str, EventCacheEntry]:
+ ) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -609,7 +578,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
- ObservableDeferred[Dict[str, EventCacheEntry]]
+ ObservableDeferred[Dict[str, _EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -632,8 +601,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
- Dict[str, EventCacheEntry]
- ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ Dict[str, _EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred())
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -689,12 +658,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
- def _invalidate_get_event_cache(self, event_id: str) -> None:
+ def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
- ) -> Dict[str, EventCacheEntry]:
+ ) -> Dict[str, _EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -767,123 +736,38 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
- def _maybe_start_fetch_thread(self) -> None:
- """Starts an event fetch thread if we are not yet at the maximum number."""
- with self._event_fetch_lock:
- if (
- self._event_fetch_list
- and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
- ):
- self._event_fetch_ongoing += 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
- # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process("fetch_events", self._fetch_thread)
-
- async def _fetch_thread(self) -> None:
- """Services requests for events from `_event_fetch_list`."""
- exc = None
- try:
- await self.db_pool.runWithConnection(self._fetch_loop)
- except BaseException as e:
- exc = e
- raise
- finally:
- should_restart = False
- event_fetches_to_fail = []
- with self._event_fetch_lock:
- self._event_fetch_ongoing -= 1
- event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
-
- # There may still be work remaining in `_event_fetch_list` if we
- # failed, or it was added in between us deciding to exit and
- # decrementing `_event_fetch_ongoing`.
- if self._event_fetch_list:
- if exc is None:
- # We decided to exit, but then some more work was added
- # before `_event_fetch_ongoing` was decremented.
- # If a new event fetch thread was not started, we should
- # restart ourselves since the remaining event fetch threads
- # may take a while to get around to the new work.
- #
- # Unfortunately it is not possible to tell whether a new
- # event fetch thread was started, so we restart
- # unconditionally. If we are unlucky, we will end up with
- # an idle fetch thread, but it will time out after
- # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
- # in any case.
- #
- # Note that multiple fetch threads may run down this path at
- # the same time.
- should_restart = True
- elif isinstance(exc, Exception):
- if self._event_fetch_ongoing == 0:
- # We were the last remaining fetcher and failed.
- # Fail any outstanding fetches since no one else will
- # handle them.
- event_fetches_to_fail = self._event_fetch_list
- self._event_fetch_list = []
- else:
- # We weren't the last remaining fetcher, so another
- # fetcher will pick up the work. This will either happen
- # after their existing work, however long that takes,
- # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
- # they are idle.
- pass
- else:
- # The exception is a `SystemExit`, `KeyboardInterrupt` or
- # `GeneratorExit`. Don't try to do anything clever here.
- pass
-
- if should_restart:
- # We exited cleanly but noticed more work.
- self._maybe_start_fetch_thread()
-
- if event_fetches_to_fail:
- # We were the last remaining fetcher and failed.
- # Fail any outstanding fetches since no one else will handle them.
- assert exc is not None
- with PreserveLoggingContext():
- for _, deferred in event_fetches_to_fail:
- deferred.errback(exc)
-
- def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
+ def _do_fetch(self, conn: Connection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- # There are no requests waiting. If we haven't yet reached the
- # maximum iteration limit, wait for some more requests to turn up.
- # Otherwise, bail out.
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- return
-
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
+ try:
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ break
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
- self._fetch_event_list(conn, event_list)
+ self._fetch_event_list(conn, event_list)
+ finally:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
def _fetch_event_list(
- self,
- conn: LoggingDatabaseConnection,
- event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
+ self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -910,7 +794,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
- def fire() -> None:
+ def fire():
for _, d in event_list:
d.callback(row_dict)
@@ -920,16 +804,18 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire_errback(exc: Exception) -> None:
- for _, d in event_list:
- d.errback(exc)
+ def fire(evs, exc):
+ for _, d in evs:
+ if not d.called:
+ with PreserveLoggingContext():
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire_errback, e)
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
async def _get_events_from_db(
- self, event_ids: Collection[str]
- ) -> Dict[str, EventCacheEntry]:
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -945,29 +831,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
- fetched_event_ids: Set[str] = set()
- fetched_events: Dict[str, _EventRow] = {}
+ fetched_events = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
- redaction_ids: Set[str] = set()
+ redaction_ids = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
- fetched_event_ids.add(event_id)
+ fetched_events[event_id] = row
if row:
- fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_event_ids)
+ events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
- event_map: Dict[str, EventBase] = {}
+ event_map = {}
for event_id, row in fetched_events.items():
+ if not row:
+ continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -995,7 +881,6 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
- room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -1066,14 +951,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
- result_map: Dict[str, EventCacheEntry] = {}
+ result_map = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
- cache_entry = EventCacheEntry(
+ cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -1082,7 +967,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
- async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
+ async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -1095,12 +980,23 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
- events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
+ events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
+
self._event_fetch_lock.notify()
- self._maybe_start_fetch_thread()
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process(
+ "fetch_events", self.db_pool.runWithConnection, self._do_fetch
+ )
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1250,7 +1146,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
- async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
+ async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1279,7 +1175,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
- The set of events we have already seen.
+ set[str]: The events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1302,9 +1198,7 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
- def have_seen_events_txn(
- txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
- ) -> None:
+ def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1330,14 +1224,12 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
+ async def have_seen_event(self, room_id: str, event_id: str):
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
- def _get_current_state_event_counts_txn(
- self, txn: LoggingTransaction, room_id: str
- ) -> int:
+ def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
"""
@@ -1362,7 +1254,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
- async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
+ async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1370,10 +1262,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
- room_id: The room ID to query.
+ room_id (str)
Returns:
- dict[str:float] of complexity version to complexity.
+ dict[str:int] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1383,13 +1275,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
- def get_current_events_token(self) -> int:
+ def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple]:
"""Returns new events, for the Events replication stream
Args:
@@ -1403,15 +1295,13 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_all_new_forward_event_rows(
- txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events AS se USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1421,9 +1311,7 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
- return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
- )
+ return txn.fetchall()
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1431,7 +1319,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1444,16 +1332,14 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
- def get_ex_outlier_stream_rows_txn(
- txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events AS se USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1464,9 +1350,7 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
- return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
- )
+ return txn.fetchall()
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1474,7 +1358,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
+ ) -> Tuple[List[Tuple[int, list]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1502,15 +1386,13 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_new_backfill_event_rows(
- txn: LoggingTransaction,
- ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
+ def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events AS se USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" AND instance_name = ?"
@@ -1518,15 +1400,7 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
- new_event_updates: List[
- Tuple[int, Tuple[str, str, str, str, str, str]]
- ] = []
- row: Tuple[int, str, str, str, str, str, str]
- # Type safety: iterating over `txn` yields `Tuple`, i.e.
- # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
- # variadic tuple to a fixed length tuple and flags it up as an error.
- for row in txn: # type: ignore[assignment]
- new_event_updates.append((row[0], row[1:]))
+ new_event_updates = [(row[0], row[1:]) for row in txn]
limited = False
if len(new_event_updates) == limit:
@@ -1537,11 +1411,11 @@ class EventsWorkerStore(SQLBaseStore):
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events AS se USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
@@ -1549,11 +1423,7 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
- # Type safety: iterating over `txn` yields `Tuple`, i.e.
- # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
- # variadic tuple to a fixed length tuple and flags it up as an error.
- for row in txn: # type: ignore[assignment]
- new_event_updates.append((row[0], row[1:]))
+ new_event_updates.extend((row[0], row[1:]) for row in txn)
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@@ -1567,7 +1437,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
- ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
+ ) -> Tuple[List[Tuple], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1587,9 +1457,7 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
- def get_all_updated_current_state_deltas_txn(
- txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str]]:
+ def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1598,23 +1466,21 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
- return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
+ return txn.fetchall()
- def get_deltas_for_stream_id_txn(
- txn: LoggingTransaction, stream_id: int
- ) -> List[Tuple[int, str, str, str, str]]:
+ def get_deltas_for_stream_id_txn(txn, stream_id):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
- return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
+ return txn.fetchall()
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
+ rows: List[Tuple] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1643,14 +1509,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
- async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
+ async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
- async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
+ async def get_event_ordering(self, event_id):
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1673,9 +1539,7 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
- def get_next_event_to_expire_txn(
- txn: LoggingTransaction,
- ) -> Optional[Tuple[str, int]]:
+ def get_next_event_to_expire_txn(txn):
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@@ -1683,7 +1547,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
- return cast(Optional[Tuple[str, int]], txn.fetchone())
+ return txn.fetchone()
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1747,10 +1611,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
- async def _cleanup_old_transaction_ids(self) -> None:
+ async def _cleanup_old_transaction_ids(self):
"""Cleans out transaction id mappings older than 24hrs."""
- def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
+ def _cleanup_old_transaction_ids_txn(txn):
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
@@ -1762,198 +1626,3 @@ class EventsWorkerStore(SQLBaseStore):
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)
-
- async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
- """Check if the given event is next to a backward gap of missing events.
- <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
-
- Args:
- room_id: room where the event lives
- event_id: event to check
-
- Returns:
- Boolean indicating whether it's an extremity
- """
-
- def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
- # If the event in question has any of its prev_events listed as a
- # backward extremity, it's next to a gap.
- #
- # We can't just check the backward edges in `event_edges` because
- # when we persist events, we will also record the prev_events as
- # edges to the event in question regardless of whether we have those
- # prev_events yet. We need to check whether those prev_events are
- # backward extremities, also known as gaps, that need to be
- # backfilled.
- backward_extremity_query = """
- SELECT 1 FROM event_backward_extremities
- WHERE
- room_id = ?
- AND %s
- LIMIT 1
- """
-
- # If the event in question is a backward extremity or has any of its
- # prev_events listed as a backward extremity, it's next to a
- # backward gap.
- clause, args = make_in_list_sql_clause(
- self.database_engine,
- "event_id",
- [event.event_id] + list(event.prev_event_ids()),
- )
-
- txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
- backward_extremities = txn.fetchall()
-
- # We consider any backward extremity as a backward gap
- if len(backward_extremities):
- return True
-
- return False
-
- return await self.db_pool.runInteraction(
- "is_event_next_to_backward_gap_txn",
- is_event_next_to_backward_gap_txn,
- )
-
- async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
- """Check if the given event is next to a forward gap of missing events.
- The gap in front of the latest events is not considered a gap.
- <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
- <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
-
- Args:
- room_id: room where the event lives
- event_id: event to check
-
- Returns:
- Boolean indicating whether it's an extremity
- """
-
- def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
- # If the event in question is a forward extremity, we will just
- # consider any potential forward gap as not a gap since it's one of
- # the latest events in the room.
- #
- # `event_forward_extremities` does not include backfilled or outlier
- # events so we can't rely on it to find forward gaps. We can only
- # use it to determine whether a message is the latest in the room.
- #
- # We can't combine this query with the `forward_edge_query` below
- # because if the event in question has no forward edges (isn't
- # referenced by any other event's prev_events) but is in
- # `event_forward_extremities`, we don't want to return 0 rows and
- # say it's next to a gap.
- forward_extremity_query = """
- SELECT 1 FROM event_forward_extremities
- WHERE
- room_id = ?
- AND event_id = ?
- LIMIT 1
- """
-
- # Check to see whether the event in question is already referenced
- # by another event. If we don't see any edges, we're next to a
- # forward gap.
- forward_edge_query = """
- SELECT 1 FROM event_edges
- /* Check to make sure the event referencing our event in question is not rejected */
- LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
- WHERE
- event_edges.room_id = ?
- AND event_edges.prev_event_id = ?
- /* It's not a valid edge if the event referencing our event in
- * question is rejected.
- */
- AND rejections.event_id IS NULL
- LIMIT 1
- """
-
- # We consider any forward extremity as the latest in the room and
- # not a forward gap.
- #
- # To expand, even though there is technically a gap at the front of
- # the room where the forward extremities are, we consider those the
- # latest messages in the room so asking other homeservers for more
- # is useless. The new latest messages will just be federated as
- # usual.
- txn.execute(forward_extremity_query, (event.room_id, event.event_id))
- forward_extremities = txn.fetchall()
- if len(forward_extremities):
- return False
-
- # If there are no forward edges to the event in question (another
- # event hasn't referenced this event in their prev_events), then we
- # assume there is a forward gap in the history.
- txn.execute(forward_edge_query, (event.room_id, event.event_id))
- forward_edges = txn.fetchall()
- if not len(forward_edges):
- return True
-
- return False
-
- return await self.db_pool.runInteraction(
- "is_event_next_to_gap_txn",
- is_event_next_to_gap_txn,
- )
-
- async def get_event_id_for_timestamp(
- self, room_id: str, timestamp: int, direction: str
- ) -> Optional[str]:
- """Find the closest event to the given timestamp in the given direction.
-
- Args:
- room_id: Room to fetch the event from
- timestamp: The point in time (inclusive) we should navigate from in
- the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
- or backward from the given timestamp to find the closest event.
-
- Returns:
- The closest event_id otherwise None if we can't find any event in
- the given direction.
- """
-
- sql_template = """
- SELECT event_id FROM events
- LEFT JOIN rejections USING (event_id)
- WHERE
- origin_server_ts %s ?
- AND room_id = ?
- /* Make sure event is not rejected */
- AND rejections.event_id IS NULL
- ORDER BY origin_server_ts %s
- LIMIT 1;
- """
-
- def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
- if direction == "b":
- # Find closest event *before* a given timestamp. We use descending
- # (which gives values largest to smallest) because we want the
- # largest possible timestamp *before* the given timestamp.
- comparison_operator = "<="
- order = "DESC"
- else:
- # Find closest event *after* a given timestamp. We use ascending
- # (which gives values smallest to largest) because we want the
- # closest possible timestamp *after* the given timestamp.
- comparison_operator = ">="
- order = "ASC"
-
- txn.execute(
- sql_template % (comparison_operator, order), (timestamp, room_id)
- )
- row = txn.fetchone()
- if row:
- (event_id,) = row
- return event_id
-
- return None
-
- if direction not in ("f", "b"):
- raise ValueError("Unknown direction: %s" % (direction,))
-
- return await self.db_pool.runInteraction(
- "get_event_id_for_timestamp_txn",
- get_event_id_for_timestamp_txn,
- )
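The events_worker.py hunks restore the inline fetch-thread management: `_enqueue_events` appends `(event_ids, deferred)` to `_event_fetch_list` under a `threading.Condition`, notifies it, and starts a background fetcher only while fewer than `EVENT_QUEUE_THREADS` are running; `_do_fetch` then drains the list, waiting up to `EVENT_QUEUE_TIMEOUT_S` per idle round and exiting after `EVENT_QUEUE_ITERATIONS` idle rounds. A simplified, Twisted-free sketch of that producer/consumer pattern (the constants and the fetch body are placeholders, not Synapse's real values):

```python
import threading

# Placeholder constants; the real ones are defined in events_worker.py.
EVENT_QUEUE_THREADS = 3
EVENT_QUEUE_ITERATIONS = 3
EVENT_QUEUE_TIMEOUT_S = 0.1

_lock = threading.Condition()
_fetch_list = []   # pending (event_ids, callback) requests
_ongoing = 0       # number of fetcher threads currently running

def enqueue(event_ids, callback):
    """Producer: queue a request and start a fetcher if below the thread cap."""
    global _ongoing
    with _lock:
        _fetch_list.append((event_ids, callback))
        _lock.notify()
        if _ongoing < EVENT_QUEUE_THREADS:
            _ongoing += 1
            threading.Thread(target=_fetch_loop).start()

def _fetch_loop():
    """Consumer: drain the queue, idling briefly before giving up."""
    global _ongoing, _fetch_list
    try:
        idle_rounds = 0
        while True:
            with _lock:
                batch, _fetch_list = _fetch_list, []
                if not batch:
                    if idle_rounds > EVENT_QUEUE_ITERATIONS:
                        return
                    _lock.wait(EVENT_QUEUE_TIMEOUT_S)
                    idle_rounds += 1
                    continue
            idle_rounds = 0
            for event_ids, callback in batch:
                # Placeholder for the real DB fetch of these event rows.
                callback({eid: "row for %s" % eid for eid in event_ids})
    finally:
        with _lock:
            _ongoing -= 1

# Example: queue two requests; both may be serviced by the same fetcher thread.
enqueue(["$event_a", "$event_b"], print)
enqueue(["$event_c"], print)
```

The batching is the point: several callers can pile requests onto the shared list and a single connection-holding thread services them together, which is the amortisation `_enqueue_events` relies on.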
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b85..3eb30944bf 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
logger.info("[purge] looking for events to delete")
- should_delete_expr = "state_events.state_key IS NULL"
+ should_delete_expr = "state_key IS NULL"
should_delete_params: Tuple[Any, ...] = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..fa782023d4 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -28,10 +28,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
- AbstractStreamIdTracker,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -85,9 +82,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
- self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
- db_conn, "push_rules_stream", "stream_id"
- )
+ self._push_rules_stream_id_gen: Union[
+ StreamIdGenerator, SlavedIdTracker
+ ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..0e8c168667 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -106,15 +106,6 @@ class RefreshTokenLookupResult:
has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
- expiry_ts: Optional[int]
- """The time at which the refresh token expires and can not be used.
- If None, the refresh token doesn't expire."""
-
- ultimate_session_expiry_ts: Optional[int]
- """The time at which the session comes to an end and can no longer be
- refreshed.
- If None, the session can be refreshed indefinitely."""
-
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
@@ -1635,10 +1626,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
rt.user_id,
rt.device_id,
rt.next_token_id,
- (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
- at.used AS has_next_access_token_been_used,
- rt.expiry_ts,
- rt.ultimate_session_expiry_ts
+ (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
+ at.used has_next_access_token_been_used
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@@ -1659,8 +1648,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
has_next_refresh_token_been_refreshed=row[4],
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
- expiry_ts=row[6],
- ultimate_session_expiry_ts=row[7],
)
return await self.db_pool.runInteraction(
@@ -1928,8 +1915,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: str,
token: str,
device_id: Optional[str],
- expiry_ts: Optional[int],
- ultimate_session_expiry_ts: Optional[int],
) -> int:
"""Adds a refresh token for the given user.
@@ -1937,13 +1922,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the refresh token.
- expiry_ts (milliseconds since the epoch): Time after which the
- refresh token cannot be used.
- If None, the refresh token never expires until it has been used.
- ultimate_session_expiry_ts (milliseconds since the epoch):
- Time at which the session will end and can not be extended any
- further.
- If None, the session can be refreshed indefinitely.
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1959,8 +1937,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"token": token,
"next_token_id": None,
- "expiry_ts": expiry_ts,
- "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
},
desc="add_refresh_token_to_user",
)
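For reference, the two refresh-token columns dropped above carried the semantics spelled out in the deleted docstrings: `expiry_ts` is the ms timestamp after which the token itself can no longer be used (None means it never expires until used), and `ultimate_session_expiry_ts` is the ms timestamp after which the whole session can no longer be refreshed (None means indefinitely). A hypothetical validity check built on those semantics might look like:

```python
from typing import Optional

def refresh_token_usable(
    now_ms: int,
    expiry_ts: Optional[int],
    ultimate_session_expiry_ts: Optional[int],
) -> bool:
    # Hypothetical helper, not part of Synapse; it only mirrors the documented
    # meaning of the two removed columns.
    if expiry_ts is not None and now_ms >= expiry_ts:
        return False  # the refresh token itself has expired
    if ultimate_session_expiry_ts is not None and now_ms >= ultimate_session_expiry_ts:
        return False  # the session as a whole can no longer be refreshed
    return True
```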
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND c.state_key = ?
+ AND state_key = ?
AND c.membership = ?
"""
else:
@@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
- AND c.state_key = ?
+ AND state_key = ?
AND m.membership = ?
"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..42dc807d17 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
oldest `limit` events.
Returns:
- The list of events (in ascending stream order) and the token from the start
+ The list of events (in ascending order) and the token from the start
of the chunk of events returned.
"""
if from_key == to_key:
@@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return [], from_key
- def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
+ def f(txn):
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -565,13 +565,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- """Fetch membership events for a given user.
-
- All such events whose stream ordering `s` lies in the range
- `from_key < s <= to_key` are returned. Events are ordered by ascending stream
- order.
- """
- # Start by ruling out cases where a DB query is not necessary.
if from_key == to_key:
return []
@@ -582,7 +575,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return []
- def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
+ def f(txn):
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@@ -641,7 +634,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
Returns:
A list of events and a token pointing to the start of the returned
- events. The events returned are in ascending topological order.
+ events. The events returned are in ascending order.
"""
rows, token = await self.get_recent_event_ids_for_room(
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..d7dc1f73ac 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,6 @@
import logging
from collections import namedtuple
-from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
@@ -45,16 +44,6 @@ _UpdateTransactionRow = namedtuple(
)
-class DestinationSortOrder(Enum):
- """Enum to define the sorting method used when returning destinations."""
-
- DESTINATION = "destination"
- RETRY_LAST_TS = "retry_last_ts"
- RETTRY_INTERVAL = "retry_interval"
- FAILURE_TS = "failure_ts"
- LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
-
-
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
"""The current destination retry timing info for a remote server."""
@@ -491,62 +480,3 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destinations = [row[0] for row in txn]
return destinations
-
- async def get_destinations_paginate(
- self,
- start: int,
- limit: int,
- destination: Optional[str] = None,
- order_by: str = DestinationSortOrder.DESTINATION.value,
- direction: str = "f",
- ) -> Tuple[List[JsonDict], int]:
- """Function to retrieve a paginated list of destinations.
- This will return a json list of destinations and the
- total number of destinations matching the filter criteria.
-
- Args:
- start: start number to begin the query from
- limit: number of rows to retrieve
- destination: search string in destination
- order_by: the sort order of the returned list
- direction: sort ascending or descending
- Returns:
- A tuple of a list of mappings from destination to information
- and a count of total destinations.
- """
-
- def get_destinations_paginate_txn(
- txn: LoggingTransaction,
- ) -> Tuple[List[JsonDict], int]:
- order_by_column = DestinationSortOrder(order_by).value
-
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- args = []
- where_statement = ""
- if destination:
- args.extend(["%" + destination.lower() + "%"])
- where_statement = "WHERE LOWER(destination) LIKE ?"
-
- sql_base = f"FROM destinations {where_statement} "
- sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
- txn.execute(sql, args)
- count = txn.fetchone()[0]
-
- sql = f"""
- SELECT destination, retry_last_ts, retry_interval, failure_ts,
- last_successful_stream_ordering
- {sql_base}
- ORDER BY {order_by_column} {order}, destination ASC
- LIMIT ? OFFSET ?
- """
- txn.execute(sql, args + [limit, start])
- destinations = self.db_pool.cursor_to_dict(txn)
- return destinations, count
-
- return await self.db_pool.runInteraction(
- "get_destinations_paginate_txn", get_destinations_paginate_txn
- )
|